Coverage for src/bartz/debug.py: 87%

442 statements  

coverage.py v7.10.1, created at 2025-07-31 16:09 +0000

1# bartz/src/bartz/debug.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" 

26 

27from collections.abc import Callable 1ab

28from dataclasses import replace 1ab

29from functools import partial 1ab

30from math import ceil, log2 1ab

31from re import fullmatch 1ab

32 

33import numpy 1ab

34from equinox import Module, field 1ab

35from jax import jit, lax, random, vmap 1ab

36from jax import numpy as jnp 1ab

37from jax.tree_util import tree_map 1ab

38from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt 1ab

39 

40from bartz.BART import FloatLike, gbart, mc_gbart 1ab

41from bartz.grove import ( 1ab

42 TreeHeaps, 

43 evaluate_forest, 

44 is_actual_leaf, 

45 is_leaves_parent, 

46 traverse_tree, 

47 tree_depth, 

48 tree_depths, 

49) 

50from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 1ab

51from bartz.jaxext import split as split_key 1ab

52from bartz.mcmcloop import TreesTrace 1ab

53from bartz.mcmcstep import randint_masked 1ab

54 

55 

56def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str: 1ab

57 """Convert a tree to a human-readable string. 

58 

59 Parameters 

60 ---------- 

61 tree 

62 A single tree to format. 

63 print_all 

64 If `True`, also print the contents of unused node slots in the arrays. 

65 

66 Returns 

67 ------- 

68 A string representation of the tree. 

69 """ 

70 tee = '├──' 1ab

71 corner = '└──' 1ab

72 join = '│ ' 1ab

73 space = ' ' 1ab

74 down = '┐' 1ab

75 bottom = '╢' # '┨' # 1ab

76 

77 def traverse_tree( 1ab

78 lines: list[str], 

79 index: int, 

80 depth: int, 

81 indent: str, 

82 first_indent: str, 

83 next_indent: str, 

84 unused: bool, 

85 ): 

86 if index >= len(tree.leaf_tree): 1ab  [branch 86 ↛ 87 never taken]

87 return 

88 

89 var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item() 1ab

90 split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item() 1ab

91 

92 is_leaf = split == 0 1ab

93 left_child = 2 * index 1ab

94 right_child = 2 * index + 1 1ab

95 

96 if print_all: 1ab  [branch 96 ↛ 97 never taken]

97 if unused: 

98 category = 'unused' 

99 elif is_leaf: 

100 category = 'leaf' 

101 else: 

102 category = 'decision' 

103 node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})' 

104 else: 

105 assert not unused 1ab

106 if is_leaf: 1ab

107 node_str = f'{tree.leaf_tree[index]:#.2g}' 1ab

108 else: 

109 node_str = f'x{var} < {split}' 1ab

110 

111 if not is_leaf or (print_all and left_child < len(tree.leaf_tree)): 1ab

112 link = down 1ab

113 elif not print_all and left_child >= len(tree.leaf_tree): 1ab

114 link = bottom 1ab

115 else: 

116 link = ' ' 1ab

117 

118 max_number = len(tree.leaf_tree) - 1 1ab

119 ndigits = len(str(max_number)) 1ab

120 number = str(index).rjust(ndigits) 1ab

121 

122 lines.append(f' {number} {indent}{first_indent}{link}{node_str}') 1ab

123 

124 indent += next_indent 1ab

125 unused = unused or is_leaf 1ab

126 

127 if unused and not print_all: 1ab

128 return 1ab

129 

130 traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused) 1ab

131 traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused) 1ab

132 

133 lines = [] 1ab

134 traverse_tree(lines, 1, 0, '', '', '', False) 1ab

135 return '\n'.join(lines) 1ab

136 

137 
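A minimal usage sketch for `format_tree`: any object exposing `leaf_tree`, `var_tree` and `split_tree` heaps works as a `TreeHeaps`; `SamplePriorTrees` (defined later in this module) is reused here purely as a convenient container, and the heap values are made up for illustration.

    from jax import numpy as jnp

    from bartz.debug import SamplePriorTrees, format_tree

    # heap layout: index 1 is the root, the children of node i are 2*i and 2*i + 1
    tree = SamplePriorTrees(
        leaf_tree=jnp.array([0.0, 0.0, -1.5, 2.0, 0.0, 0.0, 0.0, 0.0], jnp.float32),
        var_tree=jnp.array([0, 0, 0, 0], jnp.uint8),
        split_tree=jnp.array([0, 3, 0, 0], jnp.uint8),  # root splits x0 at cutpoint 3
    )
    print(format_tree(tree))  # a decision node with two leaf children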

138def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']: 1ab

139 """Measure the depth of the tree. 

140 

141 Parameters 

142 ---------- 

143 split_tree 

144 The cutpoints of the decision rules. 

145 

146 Returns 

147 ------- 

148 The depth of the deepest leaf in the tree. The root is at depth 0. 

149 """ 

150 # this could be done just with split_tree != 0 

151 is_leaf = is_actual_leaf(split_tree, add_bottom_level=True) 1ab

152 depth = tree_depths(is_leaf.size) 1ab

153 depth = jnp.where(is_leaf, depth, 0) 1ab

154 return jnp.max(depth) 1ab

155 

156 

157def forest_depth_distr( 1ab

158 split_tree: UInt[Array, 'num_trees 2**(d-1)'], 

159) -> Int32[Array, ' d']: 

160 """Histogram the depths of a set of trees. 

161 

162 Parameters 

163 ---------- 

164 split_tree 

165 The cutpoints of the decision rules of the trees. 

166 

167 Returns 

168 ------- 

169 An integer vector where the i-th element counts how many trees have depth i. 

170 """ 

171 depth = tree_depth(split_tree) + 1 1ab

172 depths = vmap(tree_actual_depth)(split_tree) 1ab

173 return jnp.bincount(depths, length=depth) 1ab

174 

175 
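An illustrative call with hand-picked values: for split heaps of size 4 the maximum depth is 2, so the histogram should have three bins, here counting two stub trees of depth 0 and one tree of depth 1.

    from jax import numpy as jnp

    from bartz.debug import forest_depth_distr

    split_tree = jnp.array(
        [[0, 3, 0, 0],   # root is a decision node -> depth 1
         [0, 0, 0, 0],   # root is a leaf -> depth 0
         [0, 0, 0, 0]],  # root is a leaf -> depth 0
        jnp.uint8,
    )
    print(forest_depth_distr(split_tree))  # expected: [2 1 0]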

176@jit 1ab

177def trace_depth_distr( 1ab

178 split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'], 

179) -> Int32[Array, 'trace_length d']: 

180 """Histogram the depths of a sequence of sets of trees. 

181 

182 Parameters 

183 ---------- 

184 split_tree 

185 The cutpoints of the decision rules of the trees. 

186 

187 Returns 

188 ------- 

189 A matrix where element (t,i) counts how many trees have depth i in set t. 

190 """ 

191 return vmap(forest_depth_distr)(split_tree) 1ab

192 

193 

194@vmap_nodoc 1ab

195def chains_depth_distr( 1ab

196 split_tree: UInt[Array, 'nchains trace_length num_trees 2**(d-1)'], 

197) -> Int32[Array, 'nchains trace_length d']: 

198 """Histogram the depths of chains of forests of trees. 

199 

200 Parameters 

201 ---------- 

202 split_tree 

203 The cutpoints of the decision rules of the trees. 

204 

205 Returns 

206 ------- 

207 A tensor where element (c,t,i) counts how many trees have depth i in forest t in chain c. 

208 """ 

209 return trace_depth_distr(split_tree) 1ab

210 

211 

212def points_per_decision_node_distr( 1ab

213 var_tree: UInt[Array, ' 2**(d-1)'], 

214 split_tree: UInt[Array, ' 2**(d-1)'], 

215 X: UInt[Array, 'p n'], 

216) -> Int32[Array, ' n+1']: 

217 """Histogram points-per-node counts. 

218 

219 Count how many parent-of-leaf nodes in a tree select each possible number

220 of points. 

221 

222 Parameters 

223 ---------- 

224 var_tree 

225 The variables of the decision rules. 

226 split_tree 

227 The cutpoints of the decision rules. 

228 X 

229 The set of points to count. 

230 

231 Returns 

232 ------- 

233 A vector where the i-th element counts how many next-to-leaf nodes have i points. 

234 """ 

235 traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None)) 1ab

236 indices = traverse_tree_X(X, var_tree, split_tree) 1ab

237 indices >>= 1 1ab

238 count_tree = jnp.zeros(split_tree.size, int).at[indices].add(1).at[0].set(0) 1ab

239 is_parent = is_leaves_parent(split_tree) 1ab

240 return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_parent) 1ab

241 

242 
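A hand-made sketch of `points_per_decision_node_distr` on one predictor and four points: with a single split `x0 < 3`, the only parent-of-leaf node is the root, which selects all 4 points.

    from jax import numpy as jnp

    from bartz.debug import points_per_decision_node_distr

    X = jnp.array([[0, 1, 5, 7]], jnp.uint8)         # p=1 predictor, n=4 points
    var_tree = jnp.array([0, 0, 0, 0], jnp.uint8)
    split_tree = jnp.array([0, 3, 0, 0], jnp.uint8)  # root splits x0 at cutpoint 3
    print(points_per_decision_node_distr(var_tree, split_tree, X))
    # expected: [0 0 0 0 1], i.e. one parent-of-leaf node holding all 4 points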

243def forest_points_per_decision_node_distr( 1ab

244 trees: TreeHeaps, X: UInt[Array, 'p n'] 

245) -> Int32[Array, ' n+1']: 

246 """Histogram points-per-node counts for a set of trees. 

247 

248 Count how many parent-of-leaf nodes in a set of trees select each possible 

249 number of points.

250 

251 Parameters 

252 ---------- 

253 trees 

254 The set of trees. The variables must have broadcast shape (num_trees,). 

255 X 

256 The set of points to count. 

257 

258 Returns 

259 ------- 

260 A vector where the i-th element counts how many next-to-leaf nodes have i points. 

261 """ 

262 distr = jnp.zeros(X.shape[1] + 1, int) 1ab

263 

264 def loop(distr, heaps: tuple[Array, Array]): 1ab

265 return distr + points_per_decision_node_distr(*heaps, X), None 1ab

266 

267 distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree)) 1ab

268 return distr 1ab

269 

270 

271@jit 1ab

272@partial(vmap_nodoc, in_axes=(0, None)) 1ab

273def chains_points_per_decision_node_distr( 1ab

274 chains: TreeHeaps, X: UInt[Array, 'p n'] 

275) -> Int32[Array, 'nchains trace_length n+1']: 

276 """Separately histogram points-per-node counts over chains of forests of trees. 

277 

278 For each set of trees, count how many parent-of-leaf nodes select each 

279 possible number of points.

280 

281 Parameters 

282 ---------- 

283 chains 

284 The chains of forests of trees. The variables must have broadcast shape 

285 (nchains, trace_length, num_trees). 

286 X 

287 The set of points to count. 

288 

289 Returns 

290 ------- 

291 A tensor where element (c,t,i) counts how many next-to-leaf nodes have i points in forest t in chain c. 

292 """ 

293 

294 def loop(_, forests): 1ab

295 return None, forest_points_per_decision_node_distr(forests, X) 1ab

296 

297 _, distr = lax.scan(loop, None, chains) 1ab

298 return distr 1ab

299 

300 

301def points_per_leaf_distr( 1ab

302 var_tree: UInt[Array, ' 2**(d-1)'], 

303 split_tree: UInt[Array, ' 2**(d-1)'], 

304 X: UInt[Array, 'p n'], 

305) -> Int32[Array, ' n+1']: 

306 """Histogram points-per-leaf counts in a tree. 

307 

308 Count how many leaves in a tree select each possible number of points.

309 

310 Parameters 

311 ---------- 

312 var_tree 

313 The variables of the decision rules. 

314 split_tree 

315 The cutpoints of the decision rules. 

316 X 

317 The set of points to count. 

318 

319 Returns 

320 ------- 

321 A vector where the i-th element counts how many leaves have i points. 

322 """ 

323 traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None)) 1ab

324 indices = traverse_tree_X(X, var_tree, split_tree) 1ab

325 count_tree = jnp.zeros(2 * split_tree.size, int).at[indices].add(1) 1ab

326 is_leaf = is_actual_leaf(split_tree, add_bottom_level=True) 1ab

327 return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_leaf) 1ab

328 

329 
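A companion sketch for `points_per_leaf_distr` with the same hand-made single-split tree: under the rule `x0 < 3`, each of the two leaves should end up with two of the four points.

    from jax import numpy as jnp

    from bartz.debug import points_per_leaf_distr

    X = jnp.array([[0, 1, 5, 7]], jnp.uint8)         # p=1 predictor, n=4 points
    var_tree = jnp.array([0, 0, 0, 0], jnp.uint8)
    split_tree = jnp.array([0, 3, 0, 0], jnp.uint8)  # root splits x0 at cutpoint 3
    print(points_per_leaf_distr(var_tree, split_tree, X))
    # expected: [0 0 2 0 0], i.e. two leaves with 2 points each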

330def forest_points_per_leaf_distr( 1ab

331 trees: TreeHeaps, X: UInt[Array, 'p n'] 

332) -> Int32[Array, ' n+1']: 

333 """Histogram points-per-leaf counts over a set of trees. 

334 

335 Count how many leaves in a set of trees select each possible number of points.

336 

337 Parameters 

338 ---------- 

339 trees 

340 The set of trees. The variables must have broadcast shape (num_trees,). 

341 X 

342 The set of points to count. 

343 

344 Returns 

345 ------- 

346 A vector where the i-th element counts how many leaves have i points. 

347 """ 

348 distr = jnp.zeros(X.shape[1] + 1, int) 1ab

349 

350 def loop(distr, heaps: tuple[Array, Array]): 1ab

351 return distr + points_per_leaf_distr(*heaps, X), None 1ab

352 

353 distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree)) 1ab

354 return distr 1ab

355 

356 

357@jit 1ab

358@partial(vmap_nodoc, in_axes=(0, None)) 1ab

359def chains_points_per_leaf_distr( 1ab

360 chains: TreeHeaps, X: UInt[Array, 'p n'] 

361) -> Int32[Array, 'nchains trace_length n+1']: 

362 """Separately histogram points-per-leaf counts over chains of forests of trees. 

363 

364 For each set of trees, count how many leaves select each possible number of

365 points. 

366 

367 Parameters 

368 ---------- 

369 chains 

370 The chains of forests of trees. The variables must have broadcast shape 

371 (nchains, trace_length, num_trees). 

372 X 

373 The set of points to count. 

374 

375 Returns 

376 ------- 

377 A tensor where element (c,t,i) counts how many leaves have i points in forest t of chain c.

378 """ 

379 

380 def loop(_, forests): 1ab

381 return None, forest_points_per_leaf_distr(forests, X) 1ab

382 

383 _, distr = lax.scan(loop, None, chains) 1ab

384 return distr 1ab

385 

386 

387check_functions = [] 1ab

388 

389 

390CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']] 1ab

391 

392 

393def check(func: CheckFunc) -> CheckFunc: 1ab

394 """Add a function to a list of functions used to check trees. 

395 

396 Use to decorate functions that check whether a tree is valid in some way. 

397 These functions are invoked automatically by `check_tree`, `check_trace` and 

398 `debug_gbart`. 

399 

400 Parameters 

401 ---------- 

402 func 

403 The function to add to the list. It must accept a `TreeHeaps` and a 

404 `max_split` argument, and return a boolean scalar that indicates if the 

405 tree is ok. 

406 

407 Returns 

408 ------- 

409 The function unchanged. 

410 """ 

411 check_functions.append(func) 1ab

412 return func 1ab

413 

414 

415@check 1ab

416def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: 1ab

417 """Check that integer types are as small as possible and coherent.""" 

418 expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1) 1ab

419 expected_split_dtype = max_split.dtype 1ab

420 return ( 1ab

421 tree.var_tree.dtype == expected_var_dtype 

422 and tree.split_tree.dtype == expected_split_dtype 

423 ) 

424 

425 

426@check 1ab

427def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: # noqa: ARG001 1ab

428 """Check that array sizes are coherent.""" 

429 return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size 1ab

430 

431 

432@check 1ab

433def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab

434 """Check that the unused node slot at index 0 is not dirty.""" 

435 return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0) 1ab

436 

437 

438@check 1ab

439def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab

440 """Check that all leaf values are neither inf nor nan."""

441 return jnp.all(jnp.isfinite(tree.leaf_tree)) 1ab

442 

443 

444@check 1ab

445def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab

446 """Check if there is any marked-non-leaf node with a marked-leaf parent.""" 

447 index = jnp.arange( 1ab

448 2 * tree.split_tree.size, 

449 dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1), 

450 ) 

451 parent_index = index >> 1 1ab

452 is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0 1ab

453 parent_is_leaf = tree.split_tree[parent_index] == 0 1ab

454 stray = is_not_leaf & parent_is_leaf 1ab

455 stray = stray.at[1].set(False) 1ab

456 return ~jnp.any(stray) 1ab

457 

458 

459@check 1ab

460def check_rule_consistency( 1ab

461 tree: TreeHeaps, max_split: UInt[Array, ' p'] 

462) -> bool | Bool[Array, '']: 

463 """Check that decision rules define proper subsets of ancestor rules.""" 

464 if tree.var_tree.size < 4: 1ab  [branch 464 ↛ 465 never taken]

465 return True 

466 

467 # initial boundaries of decision rules. use extreme integers instead of 0, 

468 # max_split to avoid checking if there is something out of bounds. 

469 small = jnp.iinfo(jnp.int32).min 1ab

470 large = jnp.iinfo(jnp.int32).max 1ab

471 lower = jnp.full(max_split.size, small, jnp.int32) 1ab

472 upper = jnp.full(max_split.size, large, jnp.int32) 1ab

473 # specify the type explicitly, otherwise they are weakly typed and get

474 # implicitly converted to split.dtype (typically uint8) in the expressions 

475 

476 def _check_recursive(node, lower, upper): 1ab

477 # read decision rule 

478 var = tree.var_tree[node] 1ab

479 split = tree.split_tree[node] 1ab

480 

481 # get rule boundaries from ancestors. use fill value in case var is 

482 # out of bounds, we don't want to check out of bounds in this function 

483 lower_var = lower.at[var].get(mode='fill', fill_value=small) 1ab

484 upper_var = upper.at[var].get(mode='fill', fill_value=large) 1ab

485 

486 # check rule is in bounds 

487 bad = jnp.where(split, (split <= lower_var) | (split >= upper_var), False) 1ab

488 

489 # recurse 

490 if node < tree.var_tree.size // 2: 1ab

491 idx = jnp.where(split, var, max_split.size) 1ab

492 bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split)) 1ab

493 bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper) 1ab

494 

495 return bad 1ab

496 

497 return ~_check_recursive(1, lower, upper) 1ab

498 

499 

500@check 1ab

501def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab

502 """Check that #leaves = 1 + #(internal nodes).""" 

503 is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True) 1ab

504 num_leaves = jnp.count_nonzero(is_leaf) 1ab

505 num_internal = jnp.count_nonzero(tree.split_tree) 1ab

506 return num_leaves == num_internal + 1 1ab

507 

508 

509@check 1ab

510def check_var_in_bounds( 1ab

511 tree: TreeHeaps, max_split: UInt[Array, ' p'] 

512) -> Bool[Array, '']: 

513 """Check that variables are in [0, max_split.size).""" 

514 decision_node = tree.split_tree.astype(bool) 1ab

515 in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size) 1ab

516 return jnp.all(in_bounds | ~decision_node) 1ab

517 

518 

519@check 1ab

520def check_split_in_bounds( 1ab

521 tree: TreeHeaps, max_split: UInt[Array, ' p'] 

522) -> Bool[Array, '']: 

523 """Check that splits are in [0, max_split[var]].""" 

524 max_split_var = ( 1ab

525 max_split.astype(jnp.int32) 

526 .at[tree.var_tree] 

527 .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max) 

528 ) 

529 return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var)) 1ab

530 

531 

532def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']: 1ab

533 """Check the validity of a tree. 

534 

535 Use `describe_error` to parse the error code returned by this function. 

536 

537 Parameters 

538 ---------- 

539 tree 

540 The tree to check. 

541 max_split 

542 The maximum split value for each variable. 

543 

544 Returns 

545 ------- 

546 An integer where each bit indicates whether a check failed. 

547 """ 

548 error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1) 1ab

549 error = error_type(0) 1ab

550 for i, func in enumerate(check_functions): 1ab

551 ok = func(tree, max_split) 1ab

552 ok = jnp.bool_(ok) 1ab

553 bit = (~ok) << i 1ab

554 error |= bit 1ab

555 return error 1ab

556 

557 

558def describe_error(error: int | Integer[Array, '']) -> list[str]: 1ab

559 """Describe the error code returned by `check_tree`. 

560 

561 Parameters 

562 ---------- 

563 error 

564 The error code returned by `check_tree`. 

565 

566 Returns 

567 ------- 

568 A list of the function names that implement the failed checks. 

569 """ 

570 return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)] 

571 

572 
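A sketch tying `check_tree` and `describe_error` together on a hand-built tree; the dtypes are derived with `minimal_unsigned_dtype` so that `check_types` passes, and the values are illustrative only.

    from jax import numpy as jnp

    from bartz.debug import SamplePriorTrees, check_tree, describe_error
    from bartz.jaxext import minimal_unsigned_dtype

    max_split = jnp.array([5, 5], jnp.uint8)  # 5 cutpoints for each of 2 variables
    var_dtype = minimal_unsigned_dtype(max_split.size - 1)
    tree = SamplePriorTrees(
        leaf_tree=jnp.array([0.0, 0.0, -1.5, 2.0, 0.0, 0.0, 0.0, 0.0], jnp.float32),
        var_tree=jnp.array([0, 1, 0, 0], var_dtype),
        split_tree=jnp.array([0, 3, 0, 0], max_split.dtype),
    )
    error = check_tree(tree, max_split)
    print(describe_error(error))  # expected: [] (every registered check passes)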

573@jit 1ab

574@partial(vmap_nodoc, in_axes=(0, None)) 1ab

575def check_trace( 1ab

576 trace: TreeHeaps, max_split: UInt[Array, ' p'] 

577) -> UInt[Array, 'trace_length num_trees']: 

578 """Check the validity of a sequence of sets of trees. 

579 

580 Use `describe_error` to parse the error codes returned by this function. 

581 

582 Parameters 

583 ---------- 

584 trace 

585 The sequence of sets of trees to check. The tree arrays must have 

586 broadcast shape (trace_length, num_trees). This object can have 

587 additional attributes beyond the tree arrays; they are ignored.

588 max_split 

589 The maximum split value for each variable. 

590 

591 Returns 

592 ------- 

593 A matrix of error codes for each tree. 

594 """ 

595 trees = TreesTrace.from_dataclass(trace) 1ab

596 return lax.map(partial(check_tree, max_split=max_split), trees) 1ab

597 

598 

599@partial(vmap_nodoc, in_axes=(0, None)) 1ab

600def check_chains( 1ab

601 chains: TreeHeaps, max_split: UInt[Array, ' p'] 

602) -> UInt[Array, 'nchains trace_length num_trees']: 

603 """Check the validity of sequences of sets of trees. 

604 

605 Use `describe_error` to parse the error codes returned by this function. 

606 

607 Parameters 

608 ---------- 

609 chains 

610 The sequences of sets of trees to check. The tree arrays must have 

611 broadcast shape (nchains, trace_length, num_trees). This object can have 

612 additional attributes beyond the tree arrays; they are ignored.

613 max_split 

614 The maximum split value for each variable. 

615 

616 Returns 

617 ------- 

618 A tensor of error codes for each tree. 

619 """ 

620 return check_trace(chains, max_split) 1ab

621 

622 

623def _get_next_line(s: str, i: int) -> tuple[str, int]: 1ab

624 """Get the next line from a string and the new index.""" 

625 i_new = s.find('\n', i) 1ab

626 if i_new == -1: 1ab  [branch 626 ↛ 627 never taken]

627 return s[i:], len(s) 

628 return s[i:i_new], i_new + 1 1ab

629 

630 

631class BARTTraceMeta(Module): 1ab

632 """Metadata of R BART tree traces. 

633 

634 Parameters 

635 ---------- 

636 ndpost 

637 The number of posterior draws. 

638 ntree 

639 The number of trees in the model. 

640 numcut 

641 The maximum split value for each variable. 

642 heap_size 

643 The size of the heap required to store the trees. 

644 """ 

645 

646 ndpost: int = field(static=True) 1ab

647 ntree: int = field(static=True) 1ab

648 numcut: UInt[Array, ' p'] 1ab

649 heap_size: int = field(static=True) 1ab

650 

651 

652def scan_BART_trees(trees: str) -> BARTTraceMeta: 1ab

653 """Scan an R BART tree trace checking for errors and parsing metadata. 

654 

655 Parameters 

656 ---------- 

657 trees 

658 The string representation of a trace of trees of the R BART package. 

659 Can be accessed from ``mc_gbart(...).treedraws['trees']``. 

660 

661 Returns 

662 ------- 

663 An object containing the metadata. 

664 

665 Raises 

666 ------ 

667 ValueError 

668 If the string is malformed or contains leftover characters. 

669 """ 

670 # parse first line 

671 line, i_char = _get_next_line(trees, 0) 1ab

672 i_line = 1 1ab

673 match = fullmatch(r'(\d+) (\d+) (\d+)', line) 1ab

674 if match is None: 1ab  [branch 674 ↛ 675 never taken]

675 msg = f'Malformed header at {i_line=}' 

676 raise ValueError(msg) 

677 ndpost, ntree, p = map(int, match.groups()) 1ab

678 

679 # initial values for maxima 

680 max_heap_index = 0 1ab

681 numcut = numpy.zeros(p, int) 1ab

682 

683 # cycle over iterations and trees 

684 for i_iter in range(ndpost): 1ab

685 for i_tree in range(ntree): 1ab

686 # parse first line of tree definition 

687 line, i_char = _get_next_line(trees, i_char) 1ab

688 i_line += 1 1ab

689 match = fullmatch(r'(\d+)', line) 1ab

690 if match is None: 1ab  [branch 690 ↛ 691 never taken]

691 msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}' 

692 raise ValueError(msg) 

693 num_nodes = int(line) 1ab

694 

695 # cycle over nodes 

696 for i_node in range(num_nodes): 1ab

697 # parse node definition 

698 line, i_char = _get_next_line(trees, i_char) 1ab

699 i_line += 1 1ab

700 match = fullmatch( 1ab

701 r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line 

702 ) 

703 if match is None: 1ab  [branch 703 ↛ 704 never taken]

704 msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}' 

705 raise ValueError(msg) 

706 i_heap = int(match.group(1)) 1ab

707 var = int(match.group(2)) 1ab

708 split = int(match.group(3)) 1ab

709 

710 # update maxima 

711 numcut[var] = max(numcut[var], split) 1ab

712 max_heap_index = max(max_heap_index, i_heap) 1ab

713 

714 assert i_char <= len(trees) 1ab

715 if i_char < len(trees): 1ab  [branch 715 ↛ 716 never taken]

716 msg = f'Leftover {len(trees) - i_char} characters in string' 

717 raise ValueError(msg) 

718 

719 # determine minimal integer type for numcut 

720 numcut += 1 # because BART is 0-based 1ab

721 split_dtype = minimal_unsigned_dtype(numcut.max()) 1ab

722 numcut = jnp.array(numcut.astype(split_dtype)) 1ab

723 

724 # determine minimum heap size to store the trees 

725 heap_size = 2 ** ceil(log2(max_heap_index + 1)) 1ab

726 

727 return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size) 1ab

728 

729 

730class TraceWithOffset(Module): 1ab

731 """Implementation of `bartz.mcmcloop.Trace`.""" 

732 

733 leaf_tree: Float32[Array, 'ndpost ntree 2**d'] 1ab

734 var_tree: UInt[Array, 'ndpost ntree 2**(d-1)'] 1ab

735 split_tree: UInt[Array, 'ndpost ntree 2**(d-1)'] 1ab

736 offset: Float32[Array, ' ndpost'] 1ab

737 

738 @classmethod 1ab

739 def from_trees_trace( 1ab

740 cls, trees: TreeHeaps, offset: Float32[Array, ''] 

741 ) -> 'TraceWithOffset': 

742 """Create a `TraceWithOffset` from a `TreeHeaps`.""" 

743 ndpost, _, _ = trees.leaf_tree.shape 1ab

744 return cls( 1ab

745 leaf_tree=trees.leaf_tree, 

746 var_tree=trees.var_tree, 

747 split_tree=trees.split_tree, 

748 offset=jnp.full(ndpost, offset), 

749 ) 

750 

751 

752def trees_BART_to_bartz( 1ab

753 trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None 

754) -> tuple[TraceWithOffset, BARTTraceMeta]: 

755 """Convert trees from the R BART format to the bartz format. 

756 

757 Parameters 

758 ---------- 

759 trees 

760 The string representation of a trace of trees of the R BART package. 

761 Can be accessed from ``mc_gbart(...).treedraws['trees']``. 

762 min_maxdepth 

763 By default, the maximum tree depth of the output is the maximum depth

764 observed in the input trees. Use this parameter to require at least

765 this maximum depth in the output format.

766 offset 

767 The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be 

768 summed to the sum of trees. To match that behavior, this function 

769 returns an offset as well, which is zero by default. Use this parameter

770 to set a different value.

771 

772 Returns 

773 ------- 

774 trace : TraceWithOffset 

775 A representation of the trees compatible with the trace returned by 

776 `bartz.mcmcloop.run_mcmc`. 

777 meta : BARTTraceMeta 

778 The metadata of the trace, containing the number of iterations, trees, 

779 and the maximum split value. 

780 """ 

781 # scan all the string checking for errors and determining sizes 

782 meta = scan_BART_trees(trees) 1ab

783 

784 # skip first line 

785 _, i_char = _get_next_line(trees, 0) 1ab

786 

787 heap_size = max(meta.heap_size, 2**min_maxdepth) 1ab

788 leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32) 1ab

789 var_trees = numpy.zeros( 1ab

790 (meta.ndpost, meta.ntree, heap_size // 2), 

791 dtype=minimal_unsigned_dtype(meta.numcut.size - 1), 

792 ) 

793 split_trees = numpy.zeros( 1ab

794 (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype 

795 ) 

796 

797 # cycle over iterations and trees 

798 for i_iter in range(meta.ndpost): 1ab

799 for i_tree in range(meta.ntree): 1ab

800 # parse first line of tree definition 

801 line, i_char = _get_next_line(trees, i_char) 1ab

802 num_nodes = int(line) 1ab

803 

804 is_internal = numpy.zeros(heap_size // 2, dtype=bool) 1ab

805 

806 # cycle over nodes 

807 for _ in range(num_nodes): 1ab

808 # parse node definition 

809 line, i_char = _get_next_line(trees, i_char) 1ab

810 values = line.split() 1ab

811 i_heap = int(values[0]) 1ab

812 var = int(values[1]) 1ab

813 split = int(values[2]) 1ab

814 leaf = float(values[3]) 1ab

815 

816 # update values 

817 leaf_trees[i_iter, i_tree, i_heap] = leaf 1ab

818 is_internal[i_heap // 2] = True 1ab

819 if i_heap < heap_size // 2: 1ab

820 var_trees[i_iter, i_tree, i_heap] = var 1ab

821 split_trees[i_iter, i_tree, i_heap] = split + 1 1ab

822 

823 is_internal[0] = False 1ab

824 split_trees[i_iter, i_tree, ~is_internal] = 0 1ab

825 

826 return TraceWithOffset( 1ab

827 leaf_tree=jnp.array(leaf_trees), 

828 var_tree=jnp.array(var_trees), 

829 split_tree=jnp.array(split_trees), 

830 offset=jnp.zeros(meta.ndpost) 

831 if offset is None 

832 else jnp.full(meta.ndpost, offset), 

833 ), meta 

834 

835 
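A round-trip sketch with a tiny hand-written trace in the format that `scan_BART_trees` expects (a `ndpost ntree p` header, then for each tree a node count followed by `heap_index var cutpoint leaf` lines). The text below is synthetic, not actual R BART output.

    from bartz.debug import trees_BART_to_bartz

    trees_str = '\n'.join([
        '1 1 2',      # ndpost=1, ntree=1, p=2
        '3',          # the only tree has 3 nodes
        '1 0 2 0.0',  # heap index, variable, cutpoint, leaf value
        '2 0 0 -1.5',
        '3 0 0 2.0',
    ]) + '\n'

    trace, meta = trees_BART_to_bartz(trees_str)
    print(meta.ndpost, meta.ntree, meta.heap_size)  # 1 1 4
    print(trace.split_tree.shape)                   # (1, 1, 2)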

836class SamplePriorStack(Module): 1ab

837 """Represent the manually managed stack used in `sample_prior`. 

838 

839 Each level of the stack represents a recursion into a child node in a 

840 binary tree of maximum depth `d`. 

841 

842 Parameters 

843 ---------- 

844 nonterminal 

845 Whether the node is an actual node of the tree, rather than the recursion having descended into unused node slots.

846 lower 

847 upper 

848 The available cutpoints along ``var`` are in the integer range 

849 ``[1 + lower[var], 1 + upper[var])``. 

850 var 

851 split 

852 The variable and cutpoint of a decision node. 

853 """ 

854 

855 nonterminal: Bool[Array, ' d-1'] 1ab

856 lower: UInt[Array, 'd-1 p'] 1ab

857 upper: UInt[Array, 'd-1 p'] 1ab

858 var: UInt[Array, ' d-1'] 1ab

859 split: UInt[Array, ' d-1'] 1ab

860 

861 @classmethod 1ab

862 def initial( 1ab

863 cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p'] 

864 ) -> 'SamplePriorStack': 

865 """Initialize the stack. 

866 

867 Parameters 

868 ---------- 

869 p_nonterminal 

870 The prior probability of a node being non-terminal conditional on 

871 its ancestors and on having available decision rules, at each depth. 

872 max_split 

873 The number of cutpoints along each variable. 

874 

875 Returns 

876 ------- 

877 A `SamplePriorStack` initialized to start the recursion. 

878 """ 

879 var_dtype = minimal_unsigned_dtype(max_split.size - 1) 1ab

880 return cls( 1ab

881 nonterminal=jnp.ones(p_nonterminal.size, bool), 

882 lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype), 

883 upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)), 

884 var=jnp.zeros(p_nonterminal.size, var_dtype), 

885 split=jnp.zeros(p_nonterminal.size, max_split.dtype), 

886 ) 

887 

888 

889class SamplePriorTrees(Module): 1ab

890 """Object holding the trees generated by `sample_prior`. 

891 

892 Parameters 

893 ---------- 

894 leaf_tree 

895 var_tree 

896 split_tree 

897 The arrays representing the trees, see `bartz.grove`. 

898 """ 

899 

900 leaf_tree: Float32[Array, '* 2**d'] 1ab

901 var_tree: UInt[Array, '* 2**(d-1)'] 1ab

902 split_tree: UInt[Array, '* 2**(d-1)'] 1ab

903 

904 @classmethod 1ab

905 def initial( 1ab

906 cls, 

907 key: Key[Array, ''], 

908 sigma_mu: Float32[Array, ''], 

909 p_nonterminal: Float32[Array, ' d-1'], 

910 max_split: UInt[Array, ' p'], 

911 ) -> 'SamplePriorTrees': 

912 """Initialize the trees. 

913 

914 The leaves are already correct and do not need to be changed. 

915 

916 Parameters 

917 ---------- 

918 key 

919 A jax random key. 

920 sigma_mu 

921 The prior standard deviation of each leaf. 

922 p_nonterminal 

923 The prior probability of a node being non-terminal conditional on 

924 its ancestors and on having available decision rules, at each depth. 

925 max_split 

926 The number of cutpoints along each variable. 

927 

928 Returns 

929 ------- 

930 Trees initialized with random leaves and stub tree structures. 

931 """ 

932 heap_size = 2 ** (p_nonterminal.size + 1) 1ab

933 return cls( 1ab

934 leaf_tree=sigma_mu * random.normal(key, (heap_size,)), 

935 var_tree=jnp.zeros( 

936 heap_size // 2, dtype=minimal_unsigned_dtype(max_split.size - 1) 

937 ), 

938 split_tree=jnp.zeros(heap_size // 2, dtype=max_split.dtype), 

939 ) 

940 

941 

942class SamplePriorCarry(Module): 1ab

943 """Object holding values carried along the recursion in `sample_prior`. 

944 

945 Parameters 

946 ---------- 

947 key 

948 A jax random key used to sample decision rules. 

949 stack 

950 The stack used to manage the recursion. 

951 trees 

952 The output arrays. 

953 """ 

954 

955 key: Key[Array, ''] 1ab

956 stack: SamplePriorStack 1ab

957 trees: SamplePriorTrees 1ab

958 

959 @classmethod 1ab

960 def initial( 1ab

961 cls, 

962 key: Key[Array, ''], 

963 sigma_mu: Float32[Array, ''], 

964 p_nonterminal: Float32[Array, ' d-1'], 

965 max_split: UInt[Array, ' p'], 

966 ) -> 'SamplePriorCarry': 

967 """Initialize the carry object. 

968 

969 Parameters 

970 ---------- 

971 key 

972 A jax random key. 

973 sigma_mu 

974 The prior standard deviation of each leaf. 

975 p_nonterminal 

976 The prior probability of a node being non-terminal conditional on 

977 its ancestors and on having available decision rules, at each depth. 

978 max_split 

979 The number of cutpoints along each variable. 

980 

981 Returns 

982 ------- 

983 A `SamplePriorCarry` initialized to start the recursion. 

984 """ 

985 keys = split_key(key) 1ab

986 return cls( 1ab

987 keys.pop(), 

988 SamplePriorStack.initial(p_nonterminal, max_split), 

989 SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split), 

990 ) 

991 

992 

993class SamplePriorX(Module): 1ab

994 """Object representing the recursion scan in `sample_prior`. 

995 

996 The sequence of nodes to visit is pre-computed recursively once, unrolling 

997 the recursion schedule. 

998 

999 Parameters 

1000 ---------- 

1001 node 

1002 The heap index of the node to visit. 

1003 depth 

1004 The depth of the node. 

1005 next_depth 

1006 The depth of the next node to visit, either the left child or the right 

1007 sibling of the node or of an ancestor. 

1008 """ 

1009 

1010 node: Int32[Array, ' 2**(d-1)-1'] 1ab

1011 depth: Int32[Array, ' 2**(d-1)-1'] 1ab

1012 next_depth: Int32[Array, ' 2**(d-1)-1'] 1ab

1013 

1014 @classmethod 1ab

1015 def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX': 1ab

1016 """Initialize the sequence of nodes to visit. 

1017 

1018 Parameters 

1019 ---------- 

1020 p_nonterminal 

1021 The prior probability of a node being non-terminal conditional on 

1022 its ancestors and on having available decision rules, at each depth. 

1023 

1024 Returns 

1025 ------- 

1026 A `SamplePriorX` initialized with the sequence of nodes to visit. 

1027 """ 

1028 seq = cls._sequence(p_nonterminal.size) 1ab

1029 assert len(seq) == 2**p_nonterminal.size - 1 1ab

1030 node = [node for node, depth in seq] 1ab

1031 depth = [depth for node, depth in seq] 1ab

1032 next_depth = depth[1:] + [p_nonterminal.size] 1ab

1033 return cls( 1ab

1034 node=jnp.array(node), 

1035 depth=jnp.array(depth), 

1036 next_depth=jnp.array(next_depth), 

1037 ) 

1038 

1039 @classmethod 1ab

1040 def _sequence( 1ab

1041 cls, max_depth: int, depth: int = 0, node: int = 1 

1042 ) -> tuple[tuple[int, int], ...]: 

1043 """Recursively generate a sequence [(node, depth), ...].""" 

1044 if depth < max_depth: 1ab

1045 out = ((node, depth),) 1ab

1046 out += cls._sequence(max_depth, depth + 1, 2 * node) 1ab

1047 out += cls._sequence(max_depth, depth + 1, 2 * node + 1) 1ab

1048 return out 1ab

1049 return () 1ab

1050 

1051 

1052def sample_prior_onetree( 1ab

1053 key: Key[Array, ''], 

1054 max_split: UInt[Array, ' p'], 

1055 p_nonterminal: Float32[Array, ' d-1'], 

1056 sigma_mu: Float32[Array, ''], 

1057) -> SamplePriorTrees: 

1058 """Sample a tree from the BART prior. 

1059 

1060 Parameters 

1061 ---------- 

1062 key 

1063 A jax random key. 

1064 max_split 

1065 The maximum split value for each variable. 

1066 p_nonterminal 

1067 The prior probability of a node being non-terminal conditional on 

1068 its ancestors and on having available decision rules, at each depth. 

1069 sigma_mu 

1070 The prior standard deviation of each leaf. 

1071 

1072 Returns 

1073 ------- 

1074 An object containing a generated tree. 

1075 """ 

1076 carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split) 1ab

1077 xs = SamplePriorX.initial(p_nonterminal) 1ab

1078 

1079 def loop(carry: SamplePriorCarry, x: SamplePriorX): 1ab

1080 keys = split_key(carry.key, 4) 1ab

1081 

1082 # get variables at current stack level 

1083 stack = carry.stack 1ab

1084 nonterminal = stack.nonterminal[x.depth] 1ab

1085 lower = stack.lower[x.depth, :] 1ab

1086 upper = stack.upper[x.depth, :] 1ab

1087 

1088 # sample a random decision rule 

1089 available: Bool[Array, ' p'] = lower < upper 1ab

1090 allowed = jnp.any(available) 1ab

1091 var = randint_masked(keys.pop(), available) 1ab

1092 split = 1 + random.randint(keys.pop(), (), lower[var], upper[var]) 1ab

1093 

1094 # cast to shorter integer types 

1095 var = var.astype(carry.trees.var_tree.dtype) 1ab

1096 split = split.astype(carry.trees.split_tree.dtype) 1ab

1097 

1098 # decide whether to try to grow the node if it is growable 

1099 pnt = p_nonterminal[x.depth] 1ab

1100 try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt) 1ab

1101 nonterminal &= try_nonterminal & allowed 1ab

1102 

1103 # update trees 

1104 trees = carry.trees 1ab

1105 trees = replace( 1ab

1106 trees, 

1107 var_tree=trees.var_tree.at[x.node].set(var), 

1108 split_tree=trees.split_tree.at[x.node].set( 

1109 jnp.where(nonterminal, split, 0) 

1110 ), 

1111 ) 

1112 

1113 def write_push_stack() -> SamplePriorStack: 1ab

1114 """Update the stack to go to the left child.""" 

1115 return replace( 1ab

1116 stack, 

1117 nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal), 

1118 lower=stack.lower.at[x.next_depth, :].set(lower), 

1119 upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)), 

1120 var=stack.var.at[x.depth].set(var), 

1121 split=stack.split.at[x.depth].set(split), 

1122 ) 

1123 

1124 def pop_push_stack() -> SamplePriorStack: 1ab

1125 """Update the stack to go to the right sibling, possibly at lower depth.""" 

1126 var = stack.var[x.next_depth - 1] 1ab

1127 split = stack.split[x.next_depth - 1] 1ab

1128 lower = stack.lower[x.next_depth - 1, :] 1ab

1129 upper = stack.upper[x.next_depth - 1, :] 1ab

1130 return replace( 1ab

1131 stack, 

1132 lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)), 

1133 upper=stack.upper.at[x.next_depth, :].set(upper), 

1134 ) 

1135 

1136 # update stack 

1137 stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack) 1ab

1138 

1139 # update carry 

1140 carry = replace(carry, key=keys.pop(), stack=stack, trees=trees) 1ab

1141 return carry, None 1ab

1142 

1143 carry, _ = lax.scan(loop, carry, xs) 1ab

1144 return carry.trees 1ab

1145 

1146 

1147@partial(vmap_nodoc, in_axes=(0, None, None, None)) 1ab

1148def sample_prior_forest( 1ab

1149 keys: Key[Array, ' num_trees'], 

1150 max_split: UInt[Array, ' p'], 

1151 p_nonterminal: Float32[Array, ' d-1'], 

1152 sigma_mu: Float32[Array, ''], 

1153) -> SamplePriorTrees: 

1154 """Sample a set of independent trees from the BART prior. 

1155 

1156 Parameters 

1157 ---------- 

1158 keys 

1159 A sequence of jax random keys, one for each tree. This determines the

1160 number of trees sampled. 

1161 max_split 

1162 The maximum split value for each variable. 

1163 p_nonterminal 

1164 The prior probability of a node being non-terminal conditional on 

1165 its ancestors and on having available decision rules, at each depth. 

1166 sigma_mu 

1167 The prior standard deviation of each leaf. 

1168 

1169 Returns 

1170 ------- 

1171 An object containing the generated trees. 

1172 """ 

1173 return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu) 1ab

1174 

1175 

1176@partial(jit, static_argnums=(1, 2)) 1ab

1177def sample_prior( 1ab

1178 key: Key[Array, ''], 

1179 trace_length: int, 

1180 num_trees: int, 

1181 max_split: UInt[Array, ' p'], 

1182 p_nonterminal: Float32[Array, ' d-1'], 

1183 sigma_mu: Float32[Array, ''], 

1184) -> SamplePriorTrees: 

1185 """Sample independent trees from the BART prior. 

1186 

1187 Parameters 

1188 ---------- 

1189 key 

1190 A jax random key. 

1191 trace_length 

1192 The number of iterations. 

1193 num_trees 

1194 The number of trees for each iteration. 

1195 max_split 

1196 The number of cutpoints along each variable. 

1197 p_nonterminal 

1198 The prior probability of a node being non-terminal conditional on 

1199 its ancestors and on having available decision rules, at each depth. 

1200 This determines the maximum depth of the trees. 

1201 sigma_mu 

1202 The prior standard deviation of each leaf. 

1203 

1204 Returns 

1205 ------- 

1206 An object containing the generated trees, with batch shape (trace_length, num_trees). 

1207 """ 

1208 keys = random.split(key, trace_length * num_trees) 1ab

1209 trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu) 1ab

1210 return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees) 1ab

1211 

1212 
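A sketch of drawing trees from the prior (the numbers are arbitrary): two entries in `p_nonterminal` allow decision nodes at depths 0 and 1, so the leaf heaps have size 8.

    from jax import numpy as jnp
    from jax import random

    from bartz.debug import sample_prior

    key = random.key(0)
    max_split = jnp.full(3, 10, jnp.uint8)               # 3 variables, 10 cutpoints each
    p_nonterminal = jnp.array([0.95, 0.5], jnp.float32)  # grow probability by depth
    sigma_mu = jnp.array(0.3, jnp.float32)
    trees = sample_prior(key, 4, 5, max_split, p_nonterminal, sigma_mu)
    print(trees.leaf_tree.shape)  # (4, 5, 8): (trace_length, num_trees, heap size)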

1213@partial(jit, static_argnames=('sum_trees',)) 1ab

1214def evaluate_forests( 1ab

1215 X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True 

1216) -> Float32[Array, 'nforests n'] | Float32[Array, 'nforests num_trees n']: 

1217 """ 

1218 Evaluate ensembles of trees at an array of points. 

1219 

1220 Parameters 

1221 ---------- 

1222 X 

1223 The coordinates to evaluate the trees at. 

1224 trees 

1225 The tree heaps, with batch shape (nforests, num_trees). 

1226 sum_trees 

1227 Whether to sum the values in each forest. 

1228 

1229 Returns 

1230 ------- 

1231 The (sum of) the values of the trees at the points in `X`. 

1232 """ 

1233 

1234 @partial(vmap, in_axes=(None, 0)) 1ab

1235 def _evaluate_forests(X, trees): 1ab

1236 return evaluate_forest(X, trees, sum_trees=sum_trees) 1ab

1237 

1238 return _evaluate_forests(X, trees) 1ab

1239 

1240 
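Continuing the `sample_prior` sketch above (reusing its `jnp` import and the `trees` it produced), the drawn forests can be evaluated at a small predictor matrix whose entries are bin indices below the corresponding `max_split`:

    from bartz.debug import evaluate_forests

    X = jnp.array([[0, 3, 7, 9],
                   [1, 1, 2, 2],
                   [5, 0, 8, 4]], jnp.uint8)  # p=3 predictors, n=4 points
    y = evaluate_forests(X, trees)            # default sum_trees=True
    print(y.shape)                            # (4, 4): (trace_length, n)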

1241class debug_mc_gbart(mc_gbart): 1ab

1242 """A subclass of `mc_gbart` that adds debugging functionality. 

1243 

1244 Parameters 

1245 ---------- 

1246 *args 

1247 Passed to `mc_gbart`. 

1248 check_trees 

1249 If `True`, check all trees with `check_trace` after running the MCMC, 

1250 and assert that they are all valid. Set to `False` to allow jax tracing. 

1251 **kw 

1252 Passed to `mc_gbart`. 

1253 """ 

1254 

1255 def __init__(self, *args, check_trees: bool = True, **kw): 1ab

1256 super().__init__(*args, **kw) 1ab

1257 if check_trees: 1ab

1258 bad = self.check_trees() 1ab

1259 bad_count = jnp.count_nonzero(bad) 1ab

1260 assert bad_count == 0 1ab

1261 

1262 def print_tree( 1ab

1263 self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False 

1264 ): 

1265 """Print a single tree in human-readable format. 

1266 

1267 Parameters 

1268 ---------- 

1269 i_chain 

1270 The index of the MCMC chain. 

1271 i_sample 

1272 The index of the (post-burnin) sample in the chain. 

1273 i_tree 

1274 The index of the tree in the sample. 

1275 print_all 

1276 If `True`, also print the content of unused node slots. 

1277 """ 

1278 tree = TreesTrace.from_dataclass(self._main_trace) 

1279 tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree) 

1280 s = format_tree(tree, print_all=print_all) 

1281 print(s) # noqa: T201, this method is intended for debug 

1282 

1283 def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']: 1ab

1284 """Return the square root of the harmonic mean of the error variance.

1285 

1286 Parameters 

1287 ---------- 

1288 prior 

1289 If `True`, use the prior distribution, otherwise use the full 

1290 conditional at the last MCMC iteration. 

1291 

1292 Returns 

1293 ------- 

1294 The square root of the harmonic mean 1/E[1/sigma^2] under the selected distribution.

1295 """ 

1296 bart = self._mcmc_state 

1297 assert bart.sigma2_alpha is not None 

1298 assert bart.z is None 

1299 if prior: 

1300 alpha = bart.sigma2_alpha 

1301 beta = bart.sigma2_beta 

1302 else: 

1303 alpha = bart.sigma2_alpha + bart.resid.size / 2 

1304 norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid) 

1305 beta = bart.sigma2_beta + norm2 / 2 

1306 sigma2 = beta / alpha 

1307 return jnp.sqrt(sigma2) 

1308 

1309 def compare_resid( 1ab

1310 self, 

1311 ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]: 

1312 """Re-compute residuals to compare them with the updated ones. 

1313 

1314 Returns 

1315 ------- 

1316 resid1 : Float32[Array, 'mc_cores n'] 

1317 The final state of the residuals updated during the MCMC. 

1318 resid2 : Float32[Array, 'mc_cores n'] 

1319 The residuals computed from the final state of the trees. 

1320 """ 

1321 bart = self._mcmc_state 1ab

1322 resid1 = bart.resid 1ab

1323 

1324 forests = TreesTrace.from_dataclass(bart.forest) 1ab

1325 trees = evaluate_forests(bart.X, forests) 1ab

1326 

1327 if bart.z is not None: 1ab

1328 ref = bart.z 1ab

1329 else: 

1330 ref = bart.y 1ab

1331 resid2 = ref - (trees + bart.offset) 1ab

1332 

1333 return resid1, resid2 1ab

1334 

1335 def avg_acc( 1ab

1336 self, 

1337 ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: 

1338 """Compute the average acceptance rates of tree moves. 

1339 

1340 Returns 

1341 ------- 

1342 acc_grow : Float32[Array, 'mc_cores'] 

1343 The average acceptance rate of grow moves. 

1344 acc_prune : Float32[Array, 'mc_cores'] 

1345 The average acceptance rate of prune moves. 

1346 """ 

1347 trace = self._main_trace 

1348 

1349 def acc(prefix): 

1350 acc = getattr(trace, f'{prefix}_acc_count') 

1351 prop = getattr(trace, f'{prefix}_prop_count') 

1352 return acc.sum(axis=1) / prop.sum(axis=1) 

1353 

1354 return acc('grow'), acc('prune') 

1355 

1356 def avg_prop( 1ab

1357 self, 

1358 ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: 

1359 """Compute the average proposal rate of grow and prune moves. 

1360 

1361 Returns 

1362 ------- 

1363 prop_grow : Float32[Array, 'mc_cores'] 

1364 The fraction of times grow was proposed instead of prune. 

1365 prop_prune : Float32[Array, 'mc_cores'] 

1366 The fraction of times prune was proposed instead of grow. 

1367 

1368 Notes 

1369 ----- 

1370 This function does not take into account cases where no move was 

1371 proposed. 

1372 """ 

1373 trace = self._main_trace 

1374 

1375 def prop(prefix): 

1376 return getattr(trace, f'{prefix}_prop_count').sum(axis=1) 

1377 

1378 pgrow = prop('grow') 

1379 pprune = prop('prune') 

1380 total = pgrow + pprune 

1381 return pgrow / total, pprune / total 

1382 

1383 def avg_move( 1ab

1384 self, 

1385 ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]: 

1386 """Compute the move rate. 

1387 

1388 Returns 

1389 ------- 

1390 rate_grow : Float32[Array, 'mc_cores'] 

1391 The fraction of times a grow move was proposed and accepted. 

1392 rate_prune : Float32[Array, 'mc_cores'] 

1393 The fraction of times a prune move was proposed and accepted. 

1394 """ 

1395 agrow, aprune = self.avg_acc() 

1396 pgrow, pprune = self.avg_prop() 

1397 return agrow * pgrow, aprune * pprune 

1398 

1399 def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']: 1ab

1400 """Histogram of tree depths for each state of the trees. 

1401 

1402 Returns 

1403 ------- 

1404 A matrix where each row contains a histogram of tree depths. 

1405 """ 

1406 return chains_depth_distr(self._main_trace.split_tree) 1ab

1407 

1408 def points_per_decision_node_distr( 1ab

1409 self, 

1410 ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:

1411 """Histogram of number of points belonging to parent-of-leaf nodes. 

1412 

1413 Returns 

1414 ------- 

1415 A matrix where each row contains a histogram of number of points. 

1416 """ 

1417 return chains_points_per_decision_node_distr( 1ab

1418 self._main_trace, self._mcmc_state.X 

1419 ) 

1420 

1421 def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']: 1ab

1422 """Histogram of number of points belonging to leaves. 

1423 

1424 Returns 

1425 ------- 

1426 A matrix where each row contains a histogram of number of points. 

1427 """ 

1428 return chains_points_per_leaf_distr(self._main_trace, self._mcmc_state.X) 1ab

1429 

1430 def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']: 1ab

1431 """Apply `check_trace` to all the tree draws.""" 

1432 return check_chains(self._main_trace, self._mcmc_state.forest.max_split) 1ab

1433 

1434 def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']: 1ab

1435 """Find iterations where a tree becomes invalid. 

1436 

1437 Returns 

1438 ------- 

1439 A boolean array where entry (i,j) is `True` if tree j is invalid at iteration i but not at iteration i-1.

1440 """ 

1441 bad = self.check_trees().astype(bool) 

1442 bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)]) 

1443 return bad & ~bad_before 

1444 

1445 

1446class debug_gbart(debug_mc_gbart, gbart): 1ab

1447 """A subclass of `gbart` that adds debugging functionality. 

1448 

1449 Parameters 

1450 ---------- 

1451 *args 

1452 Passed to `gbart`. 

1453 check_trees 

1454 If `True`, check all trees with `check_trace` after running the MCMC, 

1455 and assert that they are all valid. Set to `False` to allow jax tracing. 

1456 **kw 

1457 Passed to `gbart`. 

1458 """