Coverage for src/bartz/debug.py: 49%

111 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-29 23:01 +0000

1import functools 1ab

2 

3import jax 1ab

4from jax import lax 1ab

5from jax import numpy as jnp 1ab

6 

7from . import grove, jaxext 1ab

8 

9 

def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
    """Print a textual rendering of one tree stored in heap layout.

    The traversal starts at index 1; the children of node ``i`` are visited
    at ``2 * i`` and ``2 * i + 1``. Slot 0 is never visited.

    Parameters
    ----------
    leaf_tree : array
        Per-node leaf values. Its length bounds the traversal: indices
        ``>= len(leaf_tree)`` are not printed.
    var_tree : array
        Per-node splitting variable. Indices past its end are read as 0.
    split_tree : array
        Per-node split value; 0 means the node is a leaf. Indices past its
        end are read as 0, i.e. as leaves.
    print_all : bool, default False
        If True, print every slot reachable by index arithmetic up to
        ``len(leaf_tree)``, including slots below leaves, each labelled
        ``unused``/``leaf``/``decision`` with its raw ``(var, split, leaf)``
        contents. If False, print only the reachable tree, with leaf values
        formatted to two significant digits and decision nodes as
        ``(var: split)``.
    """
    tee = '├──'
    corner = '└──'
    # Continuation columns must be exactly as wide as the branch glyphs,
    # otherwise every level below the first drifts out of alignment.
    join = '│'.ljust(len(tee))
    space = ' ' * len(tee)
    down = '┐'
    bottom = '╢'  # alternative glyph: '┨'

    n_nodes = len(leaf_tree)
    # Width of the node-number column; invariant, so computed once.
    ndigits = len(str(n_nodes - 1))

    def traverse_tree(index, indent, first_indent, next_indent, unused):
        # Print node `index`, then recurse into its children. `indent` is the
        # prefix inherited from ancestors, `first_indent` the branch glyph
        # for this node's own line, `next_indent` what its descendants add to
        # the prefix, and `unused` whether some ancestor is already a leaf.
        if index >= n_nodes:
            return

        # Out-of-range reads yield 0, so nodes past split_tree act as leaves.
        var = var_tree.at[index].get(mode='fill', fill_value=0)
        split = split_tree.at[index].get(mode='fill', fill_value=0)

        is_leaf = split == 0
        left_child = 2 * index
        right_child = 2 * index + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
        else:
            # Without print_all the recursion stops at leaves (see below),
            # so a node below a leaf is never reached here.
            assert not unused
            if is_leaf:
                node_str = f'{leaf_tree[index]:#.2g}'
            else:
                node_str = f'({var}: {split})'

        if not is_leaf or (print_all and left_child < n_nodes):
            link = down
        elif not print_all and left_child >= n_nodes:
            link = bottom  # leaf with no child slots left in the array
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        print(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        unused = unused or is_leaf

        if unused and not print_all:
            return

        traverse_tree(left_child, indent, tee, join, unused)
        traverse_tree(right_child, indent, corner, space, unused)

    traverse_tree(1, '', '', '', False)

67 

68 

def tree_actual_depth(split_tree):
    """Return the largest node depth among the nodes that are actual leaves.

    Leaves are identified with ``grove.is_actual_leaf`` (with the bottom level
    added) and depths come from ``grove.tree_depths``; nodes that are not
    leaves contribute 0 to the maximum.
    """
    leaf_mask = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = grove.tree_depths(leaf_mask.size)
    leaf_depths = jnp.where(leaf_mask, node_depths, 0)
    return leaf_depths.max()

74 

75 

def forest_depth_distr(split_trees):
    """Histogram of ``tree_actual_depth`` over the trees of a forest.

    ``split_trees`` is mapped over its leading axis with ``jax.vmap``, one
    tree per entry. Returns integer counts of length
    ``grove.tree_depth(split_trees) + 1``, where entry ``d`` counts the trees
    whose actual depth equals ``d``.
    """
    per_tree_depth = jax.vmap(tree_actual_depth)(split_trees)
    n_bins = grove.tree_depth(split_trees) + 1
    return jnp.bincount(per_tree_depth, length=n_bins)

80 

81 

def trace_depth_distr(split_trees_trace):
    """Apply ``forest_depth_distr`` to each entry along the leading axis.

    Presumably the leading axis indexes saved samples of the forest — confirm
    against callers. Returns the per-entry histograms stacked along that axis.
    """
    per_entry = jax.vmap(forest_depth_distr)
    return per_entry(split_trees_trace)

84 

85 

def points_per_leaf_distr(var_tree, split_tree, X):
    """Histogram of how many data points land in each actual leaf of one tree.

    Parameters
    ----------
    var_tree, split_tree : array
        The decision-node arrays of a single tree.
    X : array
        Predictors; points lie along axis 1, which is the axis mapped over
        when routing points through the tree.

    Returns
    -------
    distr : array of shape (X.shape[1] + 1,)
        ``distr[k]`` is the number of actual leaves (per
        ``grove.is_actual_leaf`` with the bottom level added) that receive
        exactly ``k`` points.
    """
    traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
    indices = traverse_tree(X, var_tree, split_tree)

    # Number of points reaching each node slot; at most indices.size.
    count_tree = jnp.zeros(
        2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
    )
    count_tree = count_tree.at[indices].add(1)

    is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    # jnp.bincount returns the dtype of its weights. A bin can hold up to
    # is_leaf.size leaves, so size the weights for that instead of uint8,
    # which wraps around once more than 255 leaves share the same count.
    weights = is_leaf.astype(jaxext.minimal_unsigned_dtype(is_leaf.size))
    return jnp.bincount(count_tree, weights, length=X.shape[1] + 1)

95 

96 

def forest_points_per_leaf_distr(bart, X):
    """Sum ``points_per_leaf_distr`` over every tree of a forest.

    ``bart['var_trees']`` and ``bart['split_trees']`` hold the trees stacked
    along their leading axis, which ``lax.scan`` iterates over. Returns an
    integer array of length ``X.shape[1] + 1``.
    """
    def accumulate(total, tree_pair):
        var_tree, split_tree = tree_pair
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    stacked_trees = (bart['var_trees'], bart['split_trees'])
    total, _ = lax.scan(accumulate, initial, stacked_trees)
    return total

106 

107 

def trace_points_per_leaf_distr(bart, X):
    """Compute ``forest_points_per_leaf_distr`` for each entry of ``bart``.

    ``lax.scan`` iterates over the leading axis of every array in ``bart``;
    the per-entry distributions are returned stacked along that axis.
    """
    def per_entry(carry, forest):
        return carry, forest_points_per_leaf_distr(forest, X)

    _, stacked = lax.scan(per_entry, None, bart)
    return stacked

114 

115 

def check_types(leaf_tree, var_tree, split_tree, max_split):
    """Check the dtypes of the decision-node arrays.

    ``var_tree`` must use ``jaxext.minimal_unsigned_dtype(max_split.size - 1)``
    and ``split_tree`` must share ``max_split``'s dtype. ``leaf_tree`` is not
    inspected; it is accepted so that every check has the same signature.
    """
    var_ok = var_tree.dtype == jaxext.minimal_unsigned_dtype(max_split.size - 1)
    split_ok = split_tree.dtype == max_split.dtype
    return var_ok and split_ok

123 

124 

def check_sizes(leaf_tree, var_tree, split_tree, max_split):
    """Check that ``leaf_tree`` has twice as many entries as ``var_tree`` and
    ``split_tree`` (which must therefore be equally sized)."""
    expected_leaves = 2 * var_tree.size
    return leaf_tree.size == expected_leaves and expected_leaves == 2 * split_tree.size

127 

128 

def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
    """Check that slot 0 holds zero in both ``var_tree`` and ``split_tree``."""
    var_is_zero = var_tree[0] == 0
    split_is_zero = split_tree[0] == 0
    return var_is_zero & split_is_zero

131 

132 

def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
    """Check that every entry of ``leaf_tree`` is finite (no NaN or inf)."""
    finite_mask = jnp.isfinite(leaf_tree)
    return finite_mask.all()

135 

136 

def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
    """Check that no node with a nonzero split sits below a node whose split is 0.

    Slot ``i`` has parent ``i >> 1``. Slots past the end of ``split_tree`` are
    read as split 0. Slot 1 is exempt, since its computed parent is slot 0.
    Returns True when no stray node exists.
    """
    n_slots = 2 * split_tree.size
    slot = jnp.arange(n_slots, dtype=jaxext.minimal_unsigned_dtype(n_slots - 1))
    has_split = split_tree.at[slot].get(mode='fill', fill_value=0) != 0
    parent_has_no_split = split_tree[slot >> 1] == 0
    stray = (has_split & parent_has_no_split).at[1].set(False)
    return jnp.logical_not(stray.any())

148 

149 

# Checks run by check_tree, each called as
# func(leaf_tree, var_tree, split_tree, max_split) and returning a truthy
# value on success. The position in this list is the bit index used in the
# error mask built by check_tree and decoded by describe_error, so
# reordering changes the meaning of existing error codes.
check_functions = [
    check_types,
    check_sizes,
    check_unused_node,
    check_leaf_values,
    check_stray_nodes,
]

157 

158 

def check_tree(leaf_tree, var_tree, split_tree, max_split):
    """Run all of ``check_functions`` on one tree and pack the outcomes.

    Returns an unsigned integer whose bit ``i`` is set when
    ``check_functions[i]`` fails; 0 means every check passed. The names of
    the failed checks can be recovered with ``describe_error``.
    """
    mask_dtype = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
    error = mask_dtype(0)
    for bit_pos, check in enumerate(check_functions):
        passed = jnp.bool_(check(leaf_tree, var_tree, split_tree, max_split))
        error |= (~passed) << bit_pos
    return error

168 

169 

def describe_error(error):
    """Return the names of the checks whose bits are set in ``error``.

    ``error`` is a bitmask as produced by ``check_tree``; names are listed in
    ``check_functions`` order.
    """
    failed = []
    for bit_pos, check in enumerate(check_functions):
        if error & (1 << bit_pos):
            failed.append(check.__name__)
    return failed

172 

173 

# check_tree mapped over the leading axis of leaf_tree, var_tree and
# split_tree (one tree per entry), with max_split shared by all trees
# (in_axes None). Returns one error bitmask per tree.
check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))

175 

176 

@functools.partial(jax.vmap, in_axes=(0, None))
def check_trace(trace, state):
    """Run ``check_forest`` on every entry along the leading axis of ``trace``.

    ``trace`` must provide ``'leaf_trees'``, ``'var_trees'`` and
    ``'split_trees'``; ``state.max_split`` is shared by all entries
    (``in_axes=None``). Returns the per-tree error bitmasks, stacked along
    the mapped axis.
    """
    leaf_trees, var_trees, split_trees = (
        trace[key] for key in ('leaf_trees', 'var_trees', 'split_trees')
    )
    return check_forest(leaf_trees, var_trees, split_trees, state.max_split)