Coverage for src/bartz/debug.py: 50%

113 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-06-28 20:44 +0000

1import functools 1a

2 

3import jax 1a

4from jax import numpy as jnp 1a

5from jax import lax 1a

6 

7from . import grove 1a

8from . import mcmcstep 1a

9from . import jaxext 1a

10 

def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
    """Print a heap-indexed tree to stdout using box-drawing characters.

    Nodes are laid out in heap order: the walk starts at index 1 and the
    children of node ``i`` are ``2 * i`` and ``2 * i + 1``. Index 0 is never
    printed. Indexing uses ``.at[...].get``, so JAX arrays are expected.

    Parameters
    ----------
    leaf_tree : array
        Leaf values; ``len(leaf_tree)`` bounds the node indices visited.
    var_tree : array
        Decision variable of each node. Indices past its end read as 0.
    split_tree : array
        Split value of each node; 0 marks a leaf. Indices past its end read
        as 0, i.e. as leaves.
    print_all : bool, default False
        If True, print every index up to ``len(leaf_tree) - 1`` labelled
        ``category(var, split, leaf)``, where category is ``decision``,
        ``leaf`` or ``unused`` (a node below a leaf). If False, stop at the
        first leaf on each path; leaves are printed as their value
        (``#.2g``) and decision nodes as ``(var: split)``.
    """
    tee = '├──'
    corner = '└──'
    # NOTE(review): as listed, join/space are narrower than tee/corner
    # (3 columns); confirm the intended padding so nested levels line up.
    join = '│ '
    space = ' '
    down = '┐'
    bottom = '╢'  # alternative glyph: '┨'

    num_nodes = len(leaf_tree)
    # Width of the largest node number, so the index column is aligned.
    ndigits = len(str(num_nodes - 1))

    def visit(index, prefix, branch, child_prefix, unused):
        if index >= num_nodes:
            return

        var = var_tree.at[index].get(mode='fill', fill_value=0)
        split = split_tree.at[index].get(mode='fill', fill_value=0)
        is_leaf = split == 0
        left, right = 2 * index, 2 * index + 1
        has_children = left < num_nodes

        if print_all:
            category = 'unused' if unused else ('leaf' if is_leaf else 'decision')
            label = f'{category}({var}, {split}, {leaf_tree[index]})'
        else:
            # Non-print_all walks return before descending below a leaf.
            assert not unused
            label = f'{leaf_tree[index]:#.2g}' if is_leaf else f'({var}: {split})'

        # Glyph joining the branch to the label: a downward hook when
        # something is drawn below, a bar for leaves on the deepest level.
        if not is_leaf or (print_all and has_children):
            link = down
        elif not print_all and not has_children:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        print(f' {number} {prefix}{branch}{link}{label}')

        # Everything below a leaf is unused.
        unused = unused or is_leaf
        if unused and not print_all:
            return

        nested = prefix + child_prefix
        visit(left, nested, tee, join, unused)
        visit(right, nested, corner, space, unused)

    visit(1, '', '', '', False)

69 

def tree_actual_depth(split_tree):
    """Return the largest depth among the actual leaves of one tree.

    Node depths come from ``grove.tree_depths``; nodes not flagged by
    ``grove.is_actual_leaf(split_tree, add_bottom_level=True)`` contribute 0.
    """
    leaf_mask = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = grove.tree_depths(leaf_mask.size)
    return jnp.where(leaf_mask, node_depths, 0).max()

75 

def forest_depth_distr(split_trees):
    """Histogram of actual tree depths over a forest.

    ``tree_actual_depth`` is mapped over the leading axis of ``split_trees``.
    The result has ``grove.tree_depth(split_trees) + 1`` bins; bin ``d``
    counts the trees whose deepest actual leaf is at depth ``d``.
    """
    nbins = grove.tree_depth(split_trees) + 1
    per_tree_depth = jax.vmap(tree_actual_depth)(split_trees)
    return jnp.bincount(per_tree_depth, length=nbins)

80 

def trace_depth_distr(split_trees_trace):
    """Apply ``forest_depth_distr`` to each forest along the leading axis."""
    per_forest = jax.vmap(forest_depth_distr)
    return per_forest(split_trees_trace)

83 

def points_per_leaf_distr(var_tree, split_tree, X):
    """Histogram of how many observations land in each actual leaf of a tree.

    Parameters
    ----------
    var_tree, split_tree : array
        Decision arrays of a single tree.
    X : array
        Predictors. ``grove.traverse_tree`` is mapped over axis 1, so each
        column is one observation (presumably shape ``(p, n)`` — confirm).

    Returns
    -------
    array of length ``X.shape[1] + 1``
        Entry ``k`` is the number of actual leaves that received exactly
        ``k`` observations.
    """
    traverse_columns = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
    node_index = traverse_columns(X, var_tree, split_tree)

    # Per-node observation counts; a node can hold at most node_index.size.
    count_dtype = jaxext.minimal_unsigned_dtype(node_index.size)
    count_tree = jnp.zeros(2 * split_tree.size, dtype=count_dtype)
    count_tree = count_tree.at[node_index].add(1)

    is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    # jnp.bincount returns the dtype of its weights. Viewing the mask as
    # uint8 made the histogram wrap once more than 255 leaves shared a
    # count, so the weights are widened to hold up to the number of nodes.
    weights = is_leaf.astype(jaxext.minimal_unsigned_dtype(is_leaf.size))
    return jnp.bincount(count_tree, weights, length=X.shape[1] + 1)

91 

def forest_points_per_leaf_distr(bart, X):
    """Sum ``points_per_leaf_distr`` over every tree of a forest.

    The trees are ``bart['var_trees']`` and ``bart['split_trees']``, scanned
    together along their leading axis. The accumulator has
    ``X.shape[1] + 1`` entries of the default ``int`` dtype.
    """
    def accumulate(total, tree):
        var_tree, split_tree = tree
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    init = jnp.zeros(X.shape[1] + 1, int)
    forest = (bart['var_trees'], bart['split_trees'])
    total, _ = lax.scan(accumulate, init, forest)
    return total

99 

def trace_points_per_leaf_distr(bart, X):
    """Stack ``forest_points_per_leaf_distr`` for each entry of a trace.

    ``bart`` is scanned along its leading axis (presumably one entry per
    saved MCMC iteration — confirm); the results are stacked on that axis.
    """
    def per_entry(carry, forest):
        return carry, forest_points_per_leaf_distr(forest, X)

    _, distributions = lax.scan(per_entry, None, bart)
    return distributions

105 

def check_types(leaf_tree, var_tree, split_tree, max_split):
    """Check the dtypes of the decision arrays of a tree.

    ``var_tree`` must have the dtype ``jaxext.minimal_unsigned_dtype`` picks
    for a maximum of ``max_split.size - 1``, and ``split_tree`` must share
    ``max_split``'s dtype. ``leaf_tree`` is unused; it is part of the common
    check signature. Returns a Python bool.
    """
    var_ok = var_tree.dtype == jaxext.minimal_unsigned_dtype(max_split.size - 1)
    split_ok = split_tree.dtype == max_split.dtype
    return var_ok and split_ok

110 

def check_sizes(leaf_tree, var_tree, split_tree, max_split):
    """Check that leaf_tree is twice as long as var_tree and split_tree.

    ``var_tree`` and ``split_tree`` must have equal size. ``max_split`` is
    unused; it is part of the common check signature.
    """
    half = var_tree.size
    return leaf_tree.size == 2 * half and split_tree.size == half

113 

def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
    """Check that slot 0 of the heap layout is zeroed.

    The root lives at index 1, so index 0 of ``var_tree`` and ``split_tree``
    is not a node and must hold 0. Returns a boolean array scalar.
    """
    var_is_zero = var_tree[0] == 0
    split_is_zero = split_tree[0] == 0
    return var_is_zero & split_is_zero

116 

def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
    """Check that every entry of leaf_tree is finite (no NaN or infinity)."""
    finite = jnp.isfinite(leaf_tree)
    return finite.all()

119 

def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
    """Check that no decision node hangs below a leaf.

    A node is "stray" when its split is nonzero (not a leaf) while its parent
    ``index >> 1`` has split 0. The root (index 1) is exempt, since its parent
    slot 0 is not a real node. Indices at or past ``split_tree.size`` have no
    split entry and count as leaves, so they can never be stray; only
    ``split_tree``'s own range is examined.

    Returns a boolean array scalar, True when there are no stray nodes.
    """
    index = jnp.arange(split_tree.size)
    is_not_leaf = split_tree != 0
    parent_is_leaf = split_tree[index >> 1] == 0
    stray = is_not_leaf & parent_is_leaf
    # Exempt the root; out-of-range .at updates are dropped for tiny trees.
    stray = stray.at[1].set(False)
    return ~jnp.any(stray)

128 

# Tree validity checks run by check_tree. Each takes
# (leaf_tree, var_tree, split_tree, max_split) and returns a truthy value when
# the tree passes. The position of a check in this list is the bit it sets in
# the error code built by check_tree (bit i set <=> check_functions[i] failed),
# and describe_error decodes bits by the same positions, so reordering this
# list changes the meaning of existing error codes.
check_functions = [
    check_types,
    check_sizes,
    check_unused_node,
    check_leaf_values,
    check_stray_nodes,
]

136 

def check_tree(leaf_tree, var_tree, split_tree, max_split):
    """Run every check in ``check_functions`` on one tree.

    Returns an unsigned integer scalar (dtype chosen by
    ``jaxext.minimal_unsigned_dtype`` to fit one bit per check) in which
    bit ``i`` is set when ``check_functions[i]`` failed; 0 means all passed.
    Decode it with ``describe_error``.
    """
    error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
    error = error_type(0)
    for position, check in enumerate(check_functions):
        passed = jnp.bool_(check(leaf_tree, var_tree, split_tree, max_split))
        error |= (~passed) << position
    return error

146 

def describe_error(error):
    """Decode a ``check_tree`` error code into the names of the failed checks.

    Bit ``i`` of ``error`` maps to ``check_functions[i]``; the names are
    returned in list order.
    """
    failed = []
    for position, check in enumerate(check_functions):
        mask = 1 << position
        if error & mask:
            failed.append(check.__name__)
    return failed

153 

# check_tree vectorized over a forest: the three tree arrays are mapped along
# their leading axis, while max_split is shared by all trees. Yields one error
# code per tree.
check_forest = functools.partial(jax.vmap, in_axes=(0, 0, 0, None))(check_tree)

155 

@functools.partial(jax.vmap, in_axes=(0, None))
def check_trace(trace, state):
    """Run ``check_forest`` on each entry along the leading axis of a trace.

    ``trace`` supplies ``leaf_trees``, ``var_trees`` and ``split_trees``,
    which are mapped over. ``state['max_split']`` is shared by every entry.
    """
    leaf_trees = trace['leaf_trees']
    var_trees = trace['var_trees']
    split_trees = trace['split_trees']
    max_split = state['max_split']
    return check_forest(leaf_trees, var_trees, split_trees, max_split)