Coverage for src/bartz/debug.py: 86%

427 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/debug.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Debugging utilities. The entry point is the class `debug_gbart`.""" 

26 

27from collections.abc import Callable 1ab

28from dataclasses import replace 1ab

29from functools import partial 1ab

30from math import ceil, log2 1ab

31from re import fullmatch 1ab

32 

33import numpy 1ab

34from equinox import Module, field 1ab

35from jax import jit, lax, random, vmap 1ab

36from jax import numpy as jnp 1ab

37from jax.tree_util import tree_map 1ab

38from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt 1ab

39 

40from bartz.BART import FloatLike, gbart 1ab

41from bartz.grove import ( 1ab

42 TreeHeaps, 

43 evaluate_forest, 

44 is_actual_leaf, 

45 is_leaves_parent, 

46 traverse_tree, 

47 tree_depth, 

48 tree_depths, 

49) 

50from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 1ab

51from bartz.jaxext import split as split_key 1ab

52from bartz.mcmcloop import TreesTrace 1ab

53from bartz.mcmcstep import randint_masked 1ab

54 

55 

def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Convert a tree to a human-readable string.

    Each output line shows the heap index of a node (right-aligned), the
    box-drawing connectors that encode the tree structure, and either the
    decision rule (``x{var} < {split}``) or the leaf value.

    Parameters
    ----------
    tree
        A single tree to format.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string representation of the tree.
    """
    tee = '├──'
    corner = '└──'
    join = '│ '
    space = ' '
    down = '┐'
    bottom = '╢'  # '┨' #

    # invariants of the whole traversal, computed once instead of per node
    size = len(tree.leaf_tree)
    # width of the largest heap index, used to right-align node numbers
    ndigits = len(str(size - 1))

    # named `format_node` rather than `traverse_tree` to avoid shadowing the
    # `traverse_tree` function imported from `bartz.grove`
    def format_node(
        lines: list[str],
        index: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ):
        if index >= size:
            return

        # var_tree/split_tree are shorter than leaf_tree: indices past their
        # end read as 0, so nodes there are treated as leaves (split == 0)
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = 2 * index + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
        else:
            # without print_all, recursion stops at leaves (see below)
            assert not unused
            if is_leaf:
                node_str = f'{tree.leaf_tree[index]:#.2g}'
            else:
                node_str = f'x{var} < {split}'

        if not is_leaf or (print_all and left_child < size):
            link = down
        elif not print_all and left_child >= size:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        # everything below a leaf is an unused slot
        unused = unused or is_leaf

        if unused and not print_all:
            return

        format_node(lines, left_child, indent, tee, join, unused)
        format_node(lines, right_child, indent, corner, space, unused)

    lines = []
    format_node(lines, 1, '', '', '', False)
    return '\n'.join(lines)

136 

137 

def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
    """Compute the depth of the deepest leaf of a tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The largest depth among the actual leaves of the tree, counting the root
    as depth 0.
    """
    # note: the leaf mask could equivalently be derived from split_tree != 0
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = tree_depths(leaf_mask.size)
    return jnp.where(leaf_mask, node_depths, 0).max()

155 

156 

def forest_depth_distr(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
) -> Int32[Array, ' d']:
    """Count how many trees of a forest reach each depth.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules, one tree per row.

    Returns
    -------
    An integer vector whose i-th entry is the number of trees with depth i.
    """
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    num_bins = tree_depth(split_tree) + 1
    return jnp.bincount(per_tree_depth, length=num_bins)

174 

175 

@jit
def trace_depth_distr(
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
) -> Int32[Array, 'trace_length d']:
    """Compute the tree-depth histogram of every forest in a trace.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules, indexed by trace position and
        then by tree.

    Returns
    -------
    A matrix whose entry (t, i) is the number of trees with depth i in the
    t-th forest.
    """
    per_forest = vmap(forest_depth_distr)
    return per_forest(split_tree)

192 

193 

def points_per_decision_node_distr(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    X: UInt[Array, 'p n'],
) -> Int32[Array, ' n+1']:
    """Histogram how many points fall under each parent-of-leaves node.

    For every possible point count, tally the number of nodes of the tree
    whose children are all leaves and that receive exactly that many points.

    Parameters
    ----------
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    X
        The points to route through the tree, one per column.

    Returns
    -------
    A vector whose i-th entry counts the next-to-leaf nodes holding i points.
    """
    # heap index of the leaf reached by each point (mapped over columns of X)
    route = vmap(traverse_tree, in_axes=(1, None, None))
    leaf_of_point = route(X, var_tree, split_tree)

    # the parent of heap node i is node i >> 1
    parent_of_point = leaf_of_point >> 1
    points_per_node = (
        jnp.zeros(split_tree.size, int)
        .at[parent_of_point]
        .add(1)
        # a root leaf has "parent" 0, the unused slot: discard those counts
        .at[0]
        .set(0)
    )

    leaves_parent = is_leaves_parent(split_tree)
    histogram = jnp.zeros(X.shape[1] + 1, int)
    return histogram.at[points_per_node].add(leaves_parent)

223 

224 

def forest_points_per_decision_node_distr(
    trees: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, ' n+1']:
    """Sum the points-per-parent-of-leaves histograms over a forest.

    Parameters
    ----------
    trees
        The forest. Its arrays must have broadcast shape (num_trees,).
    X
        The points to route through the trees, one per column.

    Returns
    -------
    A vector whose i-th entry counts, across all trees, the next-to-leaf nodes
    holding i points.
    """

    def accumulate(total, var_and_split: tuple[Array, Array]):
        var_tree, split_tree = var_and_split
        return total + points_per_decision_node_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    total, _ = lax.scan(accumulate, initial, (trees.var_tree, trees.split_tree))
    return total

251 

252 

@jit
def trace_points_per_decision_node_distr(
    trace: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, 'trace_length n+1']:
    """Compute the points-per-parent-of-leaves histogram of each forest in a trace.

    Parameters
    ----------
    trace
        The sequence of forests. Its arrays must have broadcast shape
        (trace_length, num_trees).
    X
        The points to route through the trees, one per column.

    Returns
    -------
    A matrix whose entry (t, i) counts the next-to-leaf nodes holding i points
    in the t-th forest.
    """

    def per_forest(carry, forest):
        # no state is carried between forests; each gets its own histogram
        return carry, forest_points_per_decision_node_distr(forest, X)

    _, histograms = lax.scan(per_forest, None, trace)
    return histograms

280 

281 

def points_per_leaf_distr(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    X: UInt[Array, 'p n'],
) -> Int32[Array, ' n+1']:
    """Histogram how many points fall into each leaf of a tree.

    Parameters
    ----------
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    X
        The points to route through the tree, one per column.

    Returns
    -------
    A vector whose i-th entry counts the leaves holding i points.
    """
    # heap index of the leaf reached by each point (mapped over columns of X)
    route = vmap(traverse_tree, in_axes=(1, None, None))
    leaf_of_point = route(X, var_tree, split_tree)

    # leaf indices span the whole leaf heap, twice the size of split_tree
    points_per_slot = jnp.zeros(2 * split_tree.size, int).at[leaf_of_point].add(1)
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)

    histogram = jnp.zeros(X.shape[1] + 1, int)
    return histogram.at[points_per_slot].add(leaf_mask)

309 

310 

def forest_points_per_leaf_distr(
    trees: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, ' n+1']:
    """Sum the points-per-leaf histograms over a forest.

    Parameters
    ----------
    trees
        The forest. Its arrays must have broadcast shape (num_trees,).
    X
        The points to route through the trees, one per column.

    Returns
    -------
    A vector whose i-th entry counts, across all trees, the leaves holding i
    points.
    """

    def accumulate(total, var_and_split: tuple[Array, Array]):
        var_tree, split_tree = var_and_split
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    total, _ = lax.scan(accumulate, initial, (trees.var_tree, trees.split_tree))
    return total

336 

337 

@jit
def trace_points_per_leaf_distr(
    trace: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, 'trace_length n+1']:
    """Compute the points-per-leaf histogram of each forest in a trace.

    Parameters
    ----------
    trace
        The sequence of forests. Its arrays must have broadcast shape
        (trace_length, num_trees).
    X
        The points to route through the trees, one per column.

    Returns
    -------
    A matrix whose entry (t, i) counts the leaves holding i points in the t-th
    forest.
    """

    def per_forest(carry, forest):
        # no state is carried between forests; each gets its own histogram
        return carry, forest_points_per_leaf_distr(forest, X)

    _, histograms = lax.scan(per_forest, None, trace)
    return histograms

365 

366 

# Signature of a tree check: takes a tree and the per-variable maximum split,
# returns whether the tree passes.
CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']]

# Registry of tree checks, filled by the `check` decorator. `check_tree` and
# `describe_error` use the position of each function in this list as its bit
# in the error code.
check_functions: list[CheckFunc] = []

371 

372 

def check(func: CheckFunc) -> CheckFunc:
    """Register a tree-validity check.

    Decorate a function with this to have it run automatically by
    `check_tree`, `check_trace` and `debug_gbart`.

    Parameters
    ----------
    func
        The check to register. It is called with a `TreeHeaps` and a
        `max_split` array, and must return a boolean scalar that is true when
        the tree passes the check.

    Returns
    -------
    `func` itself, unmodified, so this can be used as a decorator.
    """
    registry = check_functions
    registry.append(func)
    return func

393 

394 

@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that the tree's integer dtypes are the minimal ones and agree with max_split."""
    var_ok = tree.var_tree.dtype == minimal_unsigned_dtype(max_split.size - 1)
    split_ok = tree.split_tree.dtype == max_split.dtype
    return var_ok and split_ok

404 

405 

@check
def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:  # noqa: ARG001
    """Check that leaf_tree is exactly twice as long as var_tree and split_tree."""
    num_leaf_slots = tree.leaf_tree.size
    var_ok = 2 * tree.var_tree.size == num_leaf_slots
    split_ok = 2 * tree.split_tree.size == num_leaf_slots
    return var_ok and split_ok

410 

411 

@check
def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that the unused heap slot 0 holds zeros in var_tree and split_tree."""
    return jnp.logical_and(tree.var_tree[0] == 0, tree.split_tree[0] == 0)

416 

417 

@check
def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that all leaf values are finite, i.e., not inf or nan.

    Every slot of `leaf_tree` is checked, including slots not used as leaves.
    """
    return jnp.all(jnp.isfinite(tree.leaf_tree))

422 

423 

@check
def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that no node marked as non-leaf has a parent marked as leaf."""
    num_slots = 2 * tree.split_tree.size
    node = jnp.arange(num_slots, dtype=minimal_unsigned_dtype(num_slots - 1))

    # the bottom level has no split_tree entries: read it as leaves (split 0)
    has_rule = tree.split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_has_no_rule = tree.split_tree[node >> 1] == 0

    # the root's "parent" is slot 0, which is never a decision node: skip it
    stray = (has_rule & parent_has_no_rule).at[1].set(False)
    return jnp.logical_not(stray.any())

437 

438 

@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that decision rules define proper subsets of ancestor rules.

    Walk the tree from the root, tracking for each variable the open interval
    of cutpoints still allowed by the ancestors' rules. Flag any decision node
    (split != 0) whose split falls outside the interval of its variable.
    """
    # assumes power-of-two heap sizes: with fewer than 4 slots only the root
    # can be a decision node, so there are no nested rules to compare
    if tree.var_tree.size < 4:
        return True

    # initial boundaries of decision rules. use extreme integers instead of 0,
    # max_split to avoid checking if there is something out of bounds.
    # specify the type explicitly, otherwise they are weakly typed and get
    # implicitly converted to split.dtype (typically uint8) in the expressions
    small = jnp.iinfo(jnp.int32).min
    large = jnp.iinfo(jnp.int32).max
    lower = jnp.full(max_split.size, small, jnp.int32)
    upper = jnp.full(max_split.size, large, jnp.int32)

    def _check_recursive(node, lower, upper):
        # read decision rule
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # get rule boundaries from ancestors. use fill value in case var is
        # out of bounds, we don't want to check out of bounds in this function
        lower_var = lower.at[var].get(mode='fill', fill_value=small)
        upper_var = upper.at[var].get(mode='fill', fill_value=large)

        # check rule is in bounds; split == 0 (leaf) is never flagged
        bad = jnp.where(split, (split <= lower_var) | (split >= upper_var), False)

        # recurse into the children if they have decision-rule slots. The left
        # child inherits `split` as upper bound on `var`, the right child as
        # lower bound. On a leaf the index max_split.size is out of bounds, so
        # the update is presumably dropped (JAX's default for out-of-bounds
        # scatter) and the bounds pass through unchanged.
        if node < tree.var_tree.size // 2:
            bad |= _check_recursive(
                2 * node,
                lower,
                upper.at[jnp.where(split, var, max_split.size)].set(split),
            )
            bad |= _check_recursive(
                2 * node + 1,
                lower.at[jnp.where(split, var, max_split.size)].set(split),
                upper,
            )
        return bad

    return ~_check_recursive(1, lower, upper)

484 

485 

@check
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that the leaves outnumber the decision nodes by exactly one."""
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    n_leaves = jnp.count_nonzero(leaf_mask)
    n_decision = jnp.count_nonzero(tree.split_tree)
    return n_leaves - n_decision == 1

493 

494 

@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that every decision node uses a variable in [0, max_split.size)."""
    is_decision = tree.split_tree != 0
    var_ok = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
    # leaves (split 0) may hold any variable index
    return jnp.all(~is_decision | var_ok)

503 

504 

@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that every split lies in [0, max_split[var]] for its node's variable."""
    int32_max = jnp.iinfo(jnp.int32).max
    # bound for each node's variable; an out-of-range variable reads the
    # largest int32, so this check does not flag it (variable validity is the
    # concern of check_var_in_bounds)
    bound = (
        max_split.astype(jnp.int32)
        .at[tree.var_tree]
        .get(mode='fill', fill_value=int32_max)
    )
    split_ok = (tree.split_tree >= 0) & (tree.split_tree <= bound)
    return jnp.all(split_ok)

516 

517 

def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run all registered checks on a tree and pack the outcomes into an integer.

    Decode the result with `describe_error`.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An unsigned integer whose i-th bit is set when the i-th registered check
    fails.
    """
    # smallest unsigned dtype with one bit per registered check
    code_dtype = minimal_unsigned_dtype(2 ** len(check_functions) - 1)
    error = code_dtype(0)
    for bit_position, check_func in enumerate(check_functions):
        passed = jnp.bool_(check_func(tree, max_split))
        error |= (~passed) << bit_position
    return error

542 

543 

def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Describe the error code returned by `check_tree`.

    Parameters
    ----------
    error
        The error code returned by `check_tree`. A Python int or a concrete
        (non-traced) integer scalar array.

    Returns
    -------
    A list of the function names that implement the failed checks, in
    registration order.
    """
    # Convert once to a Python int. Testing bits directly on a JAX/NumPy
    # scalar with a narrow unsigned dtype can raise OverflowError once
    # `1 << i` exceeds that dtype's range. It also dispatches one array op per
    # registered check.
    code = int(error)
    return [
        func.__name__
        for i, func in enumerate(check_functions)
        if code & (1 << i)
    ]

557 

558 

@jit
@partial(vmap_nodoc, in_axes=(0, None))
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, 'trace_length num_trees']:
    """Run `check_tree` on every tree of every forest in a trace.

    Decode the resulting codes with `describe_error`.

    Parameters
    ----------
    trace
        The sequence of forests to check. The tree arrays must have broadcast
        shape (trace_length, num_trees). Attributes other than the tree arrays
        are ignored.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    A matrix with the error code of each tree.
    """
    # keep only the tree arrays; the outer vmap already handles the trace axis
    forest = TreesTrace.from_dataclass(trace)
    per_tree = vmap(check_tree, in_axes=(0, None))
    return per_tree(forest, max_split)

584 

585 

586def _get_next_line(s: str, i: int) -> tuple[str, int]: 1ab

587 """Get the next line from a string and the new index.""" 

588 i_new = s.find('\n', i) 1ab

589 if i_new == -1: 589 ↛ 590line 589 didn't jump to line 590 because the condition on line 589 was never true1ab

590 return s[i:], len(s) 

591 return s[i:i_new], i_new + 1 1ab

592 

593 

class BARTTraceMeta(Module):
    """Metadata of R BART tree traces.

    Returned by `scan_BART_trees`.

    Parameters
    ----------
    ndpost
        The number of posterior draws.
    ntree
        The number of trees in the model.
    numcut
        The maximum split value for each variable, in the bartz convention
        (`scan_BART_trees` adds 1 to the largest 0-based R BART cutpoint
        found for each variable).
    heap_size
        The size of the heap required to store the trees (in
        `scan_BART_trees`, the smallest power of 2 larger than the largest
        heap index found).
    """

    # static fields: plain Python ints kept out of the module's array leaves
    # (equinox `field(static=True)` semantics)
    ndpost: int = field(static=True)
    ntree: int = field(static=True)
    numcut: UInt[Array, ' p']
    heap_size: int = field(static=True)

613 

614 

def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    The expected format is a header line ``ndpost ntree p``. Then, for each
    posterior draw and each tree, there is a line with the number of nodes,
    followed by one line per node of the form ``heap_index var split value``.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed or contains leftover characters. This
        includes a node with heap index 0 (the unused slot) or with a
        variable index not smaller than ``p``.
    """
    # parse first line
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # Validate indices explicitly. Heap slot 0 is the unused slot
                # in the bartz layout. `var` indexes `numcut`, which has one
                # entry per variable; an out-of-range value would otherwise
                # surface as an IndexError.
                if i_heap == 0:
                    msg = f'Invalid heap index 0 at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                if var >= p:
                    msg = f'Variable index {var} out of range for {p=} at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    split_dtype = minimal_unsigned_dtype(numcut.max())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees: the smallest power of 2
    # strictly larger than the largest heap index
    heap_size = 2 ** ceil(log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)

691 

692 

class TraceWithOffset(Module):
    """Implementation of `bartz.mcmcloop.Trace`.

    Stores a trace of forests together with the offset added to the sum of
    trees at each iteration.
    """

    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    offset: Float32[Array, ' ndpost']

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Build a `TraceWithOffset` from a `TreeHeaps`.

        The same scalar `offset` is repeated for every iteration of the trace.
        The tree arrays must be three-dimensional.
        """
        # three-way unpacking also enforces the (ndpost, ntree, heap) layout
        ndpost, _ntree, _heap_size = trees.leaf_tree.shape
        repeated_offset = jnp.full(ndpost, offset)
        return cls(
            leaf_tree=trees.leaf_tree,
            var_tree=trees.var_tree,
            split_tree=trees.split_tree,
            offset=repeated_offset,
        )

713 

714 

def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Translate an R BART tree trace into bartz heap arrays.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The heap is sized for the deepest tree found in the input. This
        parameter forces it to accommodate at least this maximum depth.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` carries an offset added
        to the sum of trees. For parity, this function returns one as well:
        zero unless this parameter is given.

    Returns
    -------
    trace : TraceWithOffset
        The trees in the layout of the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace: number of iterations and trees, and the
        maximum split value per variable.
    """
    # first pass: validate the whole string and size the output arrays
    meta = scan_BART_trees(trees)

    heap_size = max(meta.heap_size, 2**min_maxdepth)
    half_size = heap_size // 2
    batch_shape = (meta.ndpost, meta.ntree)
    leaf_trees = numpy.zeros((*batch_shape, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (*batch_shape, half_size),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros((*batch_shape, half_size), dtype=meta.numcut.dtype)

    # second pass: fill the arrays, starting right after the header line
    _, pos = _get_next_line(trees, 0)
    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            header, pos = _get_next_line(trees, pos)
            has_children = numpy.zeros(half_size, dtype=bool)

            for _ in range(int(header)):
                node_line, pos = _get_next_line(trees, pos)
                fields = node_line.split()
                i_heap, var, split = (int(token) for token in fields[:3])
                leaf = float(fields[3])

                leaf_trees[i_iter, i_tree, i_heap] = leaf
                # a node is listed only if its parent is a decision node
                has_children[i_heap // 2] = True
                if i_heap < half_size:
                    var_trees[i_iter, i_tree, i_heap] = var
                    # R BART cutpoints are 0-based; bartz reserves 0 for leaves
                    split_trees[i_iter, i_tree, i_heap] = split + 1

            # slot 0 gets marked when the root (index 1) is read: undo that
            has_children[0] = False
            split_trees[i_iter, i_tree, ~has_children] = 0

    if offset is None:
        offsets = jnp.zeros(meta.ndpost)
    else:
        offsets = jnp.full(meta.ndpost, offset)

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=offsets,
    )
    return trace, meta

797 

798 

class SamplePriorStack(Module):
    """Explicit stack replacing recursion in `sample_prior`.

    There is one stack level per recursion step into a child node, for a
    binary tree of maximum depth `d`.

    Parameters
    ----------
    nonterminal
        Whether the node at this level is valid, as opposed to the recursion
        passing through unused node slots.
    lower
    upper
        Along ``var``, the cutpoints still available are the integers in
        ``[1 + lower[var], 1 + upper[var])``.
    var
    split
        The decision rule (variable and cutpoint) of a decision node.
    """

    nonterminal: Bool[Array, ' d-1']
    lower: UInt[Array, 'd-1 p']
    upper: UInt[Array, 'd-1 p']
    var: UInt[Array, ' d-1']
    split: UInt[Array, ' d-1']

    @classmethod
    def initial(
        cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
    ) -> 'SamplePriorStack':
        """Build the stack in its state before the recursion starts.

        Parameters
        ----------
        p_nonterminal
            At each depth, the prior probability that a node is non-terminal,
            given its ancestors and given that decision rules are available.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorStack` ready for the first step of the recursion.
        """
        num_levels = p_nonterminal.size
        num_vars = max_split.size
        return cls(
            nonterminal=jnp.ones(num_levels, bool),
            lower=jnp.zeros((num_levels, num_vars), max_split.dtype),
            upper=jnp.broadcast_to(max_split, (num_levels, num_vars)),
            var=jnp.zeros(num_levels, minimal_unsigned_dtype(num_vars - 1)),
            split=jnp.zeros(num_levels, max_split.dtype),
        )

850 

851 

class SamplePriorTrees(Module):
    """Container for the trees produced by `sample_prior`.

    Parameters
    ----------
    leaf_tree
    var_tree
    split_tree
        The heap arrays that encode the trees, see `bartz.grove`.
    """

    leaf_tree: Float32[Array, '* 2**d']
    var_tree: UInt[Array, '* 2**(d-1)']
    split_tree: UInt[Array, '* 2**(d-1)']

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorTrees':
        """Create trees with random leaves and empty decision rules.

        The leaf values drawn here are final and are not modified later.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            At each depth, the prior probability that a node is non-terminal,
            given its ancestors and given that decision rules are available.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        Trees with random leaves and all-zero var/split arrays.
        """
        num_leaf_slots = 2 ** (p_nonterminal.size + 1)
        num_rule_slots = num_leaf_slots // 2
        var_dtype = minimal_unsigned_dtype(max_split.size - 1)
        return cls(
            leaf_tree=sigma_mu * random.normal(key, (num_leaf_slots,)),
            var_tree=jnp.zeros(num_rule_slots, dtype=var_dtype),
            split_tree=jnp.zeros(num_rule_slots, dtype=max_split.dtype),
        )

903 

904 

class SamplePriorCarry(Module):
    """State threaded through the recursion loop of `sample_prior`.

    Parameters
    ----------
    key
        The jax random key used to draw decision rules.
    stack
        The explicit stack that drives the recursion.
    trees
        The arrays being filled with the output trees.
    """

    key: Key[Array, '']
    stack: SamplePriorStack
    trees: SamplePriorTrees

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorCarry':
        """Build the carry in its state before the recursion starts.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            At each depth, the prior probability that a node is non-terminal,
            given its ancestors and given that decision rules are available.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorCarry` ready for the first step of the recursion.
        """
        keys = split_key(key)
        # the first key popped drives rule sampling, the second the leaves
        rule_key = keys.pop()
        leaf_key = keys.pop()
        return cls(
            key=rule_key,
            stack=SamplePriorStack.initial(p_nonterminal, max_split),
            trees=SamplePriorTrees.initial(leaf_key, sigma_mu, p_nonterminal, max_split),
        )

954 

955 

class SamplePriorX(Module):
    """Object representing the recursion scan in `sample_prior`.

    The sequence of nodes to visit is pre-computed recursively once, unrolling
    the recursion schedule into a depth-first pre-order visit of all the nodes
    that can be decision nodes.

    Parameters
    ----------
    node
        The heap index of the node to visit.
    depth
        The depth of the node.
    next_depth
        The depth of the next node to visit, either the left child or the right
        sibling of the node or of an ancestor.
    """

    node: Int32[Array, ' 2**(d-1)-1']
    depth: Int32[Array, ' 2**(d-1)-1']
    next_depth: Int32[Array, ' 2**(d-1)-1']

    @classmethod
    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
        """Initialize the sequence of nodes to visit.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.

        Returns
        -------
        A `SamplePriorX` initialized with the sequence of nodes to visit.
        """
        max_depth = p_nonterminal.size
        seq = cls._sequence(max_depth)
        assert len(seq) == 2**max_depth - 1
        node = [n for n, _ in seq]
        depth = [dep for _, dep in seq]
        # The last node has no successor: `max_depth` acts as a sentinel that
        # is deeper than any visited node.
        next_depth = depth[1:] + [max_depth]
        # An explicit int32 dtype matches the annotations regardless of
        # `jax_enable_x64`. It also keeps the arrays integer-typed when `seq`
        # is empty: `jnp.array([])` would be floating point and unusable as
        # an index.
        return cls(
            node=jnp.array(node, dtype=jnp.int32),
            depth=jnp.array(depth, dtype=jnp.int32),
            next_depth=jnp.array(next_depth, dtype=jnp.int32),
        )

    @classmethod
    def _sequence(
        cls, max_depth: int, depth: int = 0, node: int = 1
    ) -> tuple[tuple[int, int], ...]:
        """Recursively generate a sequence [(node, depth), ...].

        Nodes are listed in depth-first pre-order (node, left subtree, right
        subtree) with heap indexing: the children of ``node`` are
        ``2 * node`` and ``2 * node + 1``. Depths from ``depth`` up to
        ``max_depth - 1`` are included.
        """
        if depth < max_depth:
            out = ((node, depth),)
            out += cls._sequence(max_depth, depth + 1, 2 * node)
            out += cls._sequence(max_depth, depth + 1, 2 * node + 1)
            return out
        return ()

1013 

1014 

def sample_prior_onetree(
    key: Key[Array, ''],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a tree from the BART prior.

    The recursive generative process is unrolled into a single `lax.scan` over
    the candidate decision nodes in depth-first pre-order (`SamplePriorX`). A
    per-depth stack (`SamplePriorStack`) stands in for the call stack.

    Parameters
    ----------
    key
        A jax random key.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing a generated tree.
    """
    carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
    xs = SamplePriorX.initial(p_nonterminal)

    def loop(carry: SamplePriorCarry, x: SamplePriorX):
        # One scan step: decide the decision rule of node `x.node`, then move
        # the stack to the next node in pre-order. The 4 keys are consumed in
        # order: variable, split, bernoulli, next carry key.
        keys = split_key(carry.key, 4)

        # get variables at current stack level
        stack = carry.stack
        # Flag inherited from the parent (written by `write_push_stack`); if
        # False, this node cannot be a decision node.
        nonterminal = stack.nonterminal[x.depth]
        # Per-variable bounds left by the ancestors' rules: the valid splits
        # for variable i are lower[i] + 1, ..., upper[i].
        lower = stack.lower[x.depth, :]
        upper = stack.upper[x.depth, :]

        # sample a random decision rule
        available: Bool[Array, ' p'] = lower < upper
        allowed = jnp.any(available)
        # variable index restricted to `available` (see `randint_masked`)
        var = randint_masked(keys.pop(), available)
        # `random.randint` excludes maxval, so split is in [lower + 1, upper]
        split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])

        # cast to shorter integer types
        var = var.astype(carry.trees.var_tree.dtype)
        split = split.astype(carry.trees.split_tree.dtype)

        # decide whether to try to grow the node if it is growable
        pnt = p_nonterminal[x.depth]
        try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt)
        nonterminal &= try_nonterminal & allowed

        # update trees
        # var is written unconditionally; split is written as 0 for nodes that
        # did not become decision nodes.
        trees = carry.trees
        trees = replace(
            trees,
            var_tree=trees.var_tree.at[x.node].set(var),
            split_tree=trees.split_tree.at[x.node].set(
                jnp.where(nonterminal, split, 0)
            ),
        )

        def write_push_stack() -> SamplePriorStack:
            """Update the stack to go to the left child."""
            # The left child inherits this node's flag and bounds, with
            # upper[var] lowered so that its splits stay below `split`. The
            # rule (var, split) is saved at this depth for the right child.
            return replace(
                stack,
                nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
                lower=stack.lower.at[x.next_depth, :].set(lower),
                upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
                var=stack.var.at[x.depth].set(var),
                split=stack.split.at[x.depth].set(split),
            )

        def pop_push_stack() -> SamplePriorStack:
            """Update the stack to go to the right sibling, possibly at lower depth."""
            # The next node is the right child of the last node visited at
            # depth next_depth - 1. Rebuild its bounds from that parent's
            # bounds and saved rule, with lower[var] raised to the split. The
            # nonterminal flag at next_depth is still the one the parent wrote
            # for its left child, so it is not touched.
            var = stack.var[x.next_depth - 1]
            split = stack.split[x.next_depth - 1]
            lower = stack.lower[x.next_depth - 1, :]
            upper = stack.upper[x.next_depth - 1, :]
            return replace(
                stack,
                lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
                upper=stack.upper.at[x.next_depth, :].set(upper),
            )

        # update stack
        # a deeper next node is the left child, otherwise it is a right sibling
        stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)

        # update carry
        carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
        return carry, None

    carry, _ = lax.scan(loop, carry, xs)
    return carry.trees

1108 

1109 

@partial(vmap_nodoc, in_axes=(0, None, None, None))
def sample_prior_forest(
    keys: Key[Array, ' num_trees'],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw a batch of independent trees from the BART prior.

    This is `sample_prior_onetree` mapped over the first axis of `keys`.
    Inside the body, `keys` is therefore a single key.

    Parameters
    ----------
    keys
        A batch of jax random keys, one per tree. Its length determines how
        many trees are drawn.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees.
    """
    single_tree = sample_prior_onetree(
        keys, max_split=max_split, p_nonterminal=p_nonterminal, sigma_mu=sigma_mu
    )
    return single_tree

1137 

1138 

@partial(jit, static_argnums=(1, 2))
def sample_prior(
    key: Key[Array, ''],
    trace_length: int,
    num_trees: int,
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample independent trees from the BART prior.

    Parameters
    ----------
    key
        A jax random key.
    trace_length
        The number of iterations.
    num_trees
        The number of trees for each iteration.
    max_split
        The number of cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
        This determines the maximum depth of the trees.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with batch shape (trace_length, num_trees).
    """
    keys = random.split(key, trace_length * num_trees)
    trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu)
    # Split the flat batch axis into (trace_length, num_trees), keeping the
    # trailing axes explicitly. Using -1 fails when the batch is empty
    # (trace_length or num_trees is 0), because -1 cannot be inferred for a
    # zero-size array.
    return tree_map(
        lambda x: x.reshape(trace_length, num_trees, *x.shape[1:]), trees
    )

1174 

1175 

class debug_gbart(gbart):
    """A subclass of `gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid. Set to `False` to allow jax tracing.
    **kw
        Passed to `gbart`.
    """

    def __init__(self, *args, check_trees: bool = True, **kw):
        super().__init__(*args, **kw)
        if check_trees:
            bad = self.check_trees()
            bad_count = jnp.count_nonzero(bad)
            # Report how many tree draws failed, to ease debugging.
            assert bad_count == 0, (
                f'{bad_count} invalid tree draws found by check_trace; '
                'use tree_goes_bad() to locate them'
            )

    def show_tree(self, i_sample: int, i_tree: int, print_all: bool = False):
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_sample
            The index of the posterior sample.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        tree = TreesTrace.from_dataclass(self._main_trace)
        tree = tree_map(lambda x: x[i_sample, i_tree, :], tree)
        s = format_tree(tree, print_all=print_all)
        print(s)  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, '']:
        """Return the square root of the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        sqrt(1/E[1/sigma^2]) in the selected distribution. This is the
        harmonic mean of the error variance, square-rooted to be on the scale
        of sigma. With sigma^2 ~ inverse-gamma(alpha, beta),
        1/E[1/sigma^2] = beta/alpha.
        """
        bart = self._mcmc_state
        # requires a prior on the error variance, and no latent variable z
        assert bart.sigma2_alpha is not None
        assert bart.z is None
        if prior:
            alpha = bart.sigma2_alpha
            beta = bart.sigma2_beta
        else:
            # conjugate update with the current residuals
            resid = bart.resid
            alpha = bart.sigma2_alpha + resid.size / 2
            norm2 = resid @ resid
            beta = bart.sigma2_beta + norm2 / 2
        sigma2 = beta / alpha
        return jnp.sqrt(sigma2)

    def compare_resid(self) -> tuple[Float32[Array, ' n'], Float32[Array, ' n']]:
        """Re-compute residuals to compare them with the updated ones.

        Returns
        -------
        resid1 : Float32[Array, 'n']
            The final state of the residuals updated during the MCMC.
        resid2 : Float32[Array, 'n']
            The residuals computed from the final state of the trees.
        """
        bart = self._mcmc_state
        resid1 = bart.resid

        trees = evaluate_forest(bart.X, bart.forest)

        # the latent variable z, when present, replaces y as the target
        if bart.z is not None:
            ref = bart.z
        else:
            ref = bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def avg_acc(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, '']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, '']
            The average acceptance rate of prune moves.

        Notes
        -----
        A rate is NaN if the corresponding move was never proposed (0/0).
        """
        trace = self._main_trace

        def acc(prefix):
            accepted = getattr(trace, f'{prefix}_acc_count')
            proposed = getattr(trace, f'{prefix}_prop_count')
            return accepted.sum() / proposed.sum()

        return acc('grow'), acc('prune')

    def avg_prop(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, '']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, '']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed. If no move was ever proposed, both results are NaN (0/0).
        """
        trace = self._main_trace

        def prop(prefix):
            return getattr(trace, f'{prefix}_prop_count').sum()

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def avg_move(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, '']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, '']
            The fraction of times a prune move was proposed and accepted.
        """
        agrow, aprune = self.avg_acc()
        pgrow, pprune = self.avg_prop()
        return agrow * pgrow, aprune * pprune

    def depth_distr(self) -> Float32[Array, 'trace_length d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        return trace_depth_distr(self._main_trace.split_tree)

    def points_per_decision_node_distr(self) -> Float32[Array, 'trace_length n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return trace_points_per_decision_node_distr(
            self._main_trace, self._mcmc_state.X
        )

    def points_per_leaf_distr(self) -> Float32[Array, 'trace_length n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)

    def check_trees(self) -> UInt[Array, 'trace_length ntree']:
        """Apply `check_trace` to all the tree draws."""
        return check_trace(self._main_trace, self._mcmc_state.forest.max_split)

    def tree_goes_bad(self) -> Bool[Array, 'trace_length ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        A boolean matrix where element (i, j) is `True` if tree j is invalid at
        iteration i but not at iteration i - 1. For i = 0, it is `True` if the
        tree is invalid at all.
        """
        bad = self.check_trees().astype(bool)
        # shift down by one iteration, treating the state before the first
        # iteration as valid
        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
        return bad & ~bad_before