Coverage for src/bartz/BART.py: 78% (236 statements)


# bartz/src/bartz/BART.py
#
# Copyright (c) 2024-2025, Giacomo Petrillo
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement a user interface that mimics the R BART package."""

import functools
import math
from typing import Any, Literal

import jax
import jax.numpy as jnp
from jax.scipy.special import ndtri
from jaxtyping import Array, Bool, Float, Float32

from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars

FloatLike = float | Float[Any, '']


class gbart:
    """
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train : array (p, n) or DataFrame
        The training predictors.
    y_train : array (n,) or Series
        The training responses.
    x_test : array (p, m) or DataFrame, optional
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    usequants : bool, default False
        Whether to use the predictors' quantiles instead of a uniform grid to
        bin the predictors.
    sigest : float, optional
        An estimate of the residual standard deviation on `y_train`, used to set
        `lamda`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). If `y_train` has fewer
        than two elements, it is set to 1. If n <= p, it is set to the standard
        deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf : int, default 3
        The degrees of freedom of the scaled inverse chi-squared prior on the
        noise variance.
    sigquant : float, default 0.9
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k : float, default 2
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has fewer than two elements, `k` is ignored and the scale is set to 1.
    power : float, default 2
    base : float, default 0.95
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
        continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
        `offset` is set to 0.
    w : array (n,), optional
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree : int, default 200
        The number of trees used to represent the latent mean function.
    numcut : int, default 255
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, and at most ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type.
    ndpost : int, default 1000
        The number of MCMC samples to save, after burn-in.
    nskip : int, default 100
        The number of initial MCMC samples to discard as burn-in.
    keepevery : int, default 1
        The thinning factor for the MCMC samples, after burn-in.
    printevery : int or None, default 100
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging.

        `printevery` has a few unexpected side effects. On CPU, interrupting
        with ^C halts the MCMC only at the next log. And the total number of
        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
        ndpost`` is not a multiple of `printevery`, some of the last iterations
        will not be saved.
    seed : int or jax random key, default 0
        The seed for the random number generator.
    maxdepth : int, default 6
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw : dict
        Additional arguments passed to `mcmcstep.init`.
    run_mcmc_kw : dict
        Additional arguments passed to `mcmcloop.run_mcmc`.

    Attributes
    ----------
    yhat_train : array (ndpost, n)
        The conditional posterior mean at `x_train` for each MCMC iteration.
    yhat_train_mean : array (n,)
        The marginal posterior mean at `x_train`.
    yhat_test : array (ndpost, m)
        The conditional posterior mean at `x_test` for each MCMC iteration.
    yhat_test_mean : array (m,)
        The marginal posterior mean at `x_test`.
    sigma : array (ndpost,)
        The standard deviation of the error.
    first_sigma : array (nskip,)
        The standard deviation of the error in the burn-in phase.
    offset : float
        The prior mean of the latent mean function.
    sigest : float or None
        The estimated standard deviation of the error used to set `lamda`.

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If `type` is not specified, it is determined solely based on the data type
      of `y_train`, and not on whether it contains only two unique values.
    - If ``usequants=False``, R BART switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - `rm_const` is always `False`.
    - The default `numcut` is 255 instead of 100.
    - A lot of functionality is missing (e.g., variable selection).
    - There are some additional attributes, and some are missing.
    - The trees have a maximum depth.

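    Examples
    --------
    A minimal usage sketch on synthetic data (the sizes and MCMC settings shown
    here are illustrative only):

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((5, 50))  # (p, n): one predictor per row
    >>> y = X[0] + 0.1 * rng.standard_normal(50)
    >>> fit = gbart(X, y, ndpost=200, nskip=100, printevery=None)
    >>> fit.yhat_train_mean.shape  # marginal posterior mean at x_train
    (50,)
    >>> fit.sigma.shape  # one error standard deviation draw per saved sample
    (200,)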

    """

    def __init__(
        self,
        x_train,
        y_train,
        *,
        x_test=None,
        type: Literal['wbart', 'pbart'] = 'wbart',
        usequants=False,
        sigest=None,
        sigdf=3,
        sigquant=0.9,
        k=2,
        power=2,
        base=0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w=None,
        ntree=200,
        numcut=255,
        ndpost=1000,
        nskip=100,
        keepevery=1,
        printevery=100,
        seed=0,
        maxdepth=6,
        init_kw=None,
        run_mcmc_kw=None,
    ):
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train, _ = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            w, _ = self._process_response_input(w)
            self._check_same_length(x_train, w)

        y_train = self._process_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        splits, max_split = self._determine_splits(x_train, usequants, numcut)
        x_train = self._bin_predictors(x_train, splits)

        mcmc_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            ntree,
            init_kw,
        )
        final_state, burnin_trace, main_trace = self._run_mcmc(
            mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        sigma = self._extract_sigma(main_trace)
        first_sigma = self._extract_sigma(burnin_trace)

        self.offset = final_state.offset  # from the state because of buffer donation
        self.sigest = sigest
        self.sigma = sigma
        self.first_sigma = first_sigma

        self._x_train_fmt = x_train_fmt
        self._splits = splits
        self._main_trace = main_trace
        self._mcmc_state = final_state

        if x_test is not None:
            yhat_test = self.predict(x_test)
            self.yhat_test = yhat_test
            self.yhat_test_mean = yhat_test.mean(axis=0)

    @functools.cached_property
    def yhat_train(self):
        x_train = self._mcmc_state.X
        return self._predict(self._main_trace, x_train)

    @functools.cached_property
    def yhat_train_mean(self):
        return self.yhat_train.mean(axis=0)

    def predict(self, x_test):
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test : array (p, m) or DataFrame
            The test predictors.

        Returns
        -------
        yhat_test : array (ndpost, m)
            The conditional posterior mean at `x_test` for each MCMC iteration.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`.
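
        Examples
        --------
        A minimal sketch with synthetic data (the sizes and settings are
        illustrative only):

        >>> import numpy as np
        >>> rng = np.random.default_rng(0)
        >>> X = rng.standard_normal((2, 30))
        >>> fit = gbart(X, X.sum(axis=0), ndpost=10, nskip=10, printevery=None)
        >>> fit.predict(rng.standard_normal((2, 7))).shape  # (ndpost, m)
        (10, 7)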

        """
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        if x_test_fmt != self._x_train_fmt:
            raise ValueError(
                f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
            )
        x_test = self._bin_predictors(x_test, self._splits)
        return self._predict(self._main_trace, x_test)

    @staticmethod
    def _process_predictor_input(x):
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(y):
        if hasattr(y, 'to_numpy'):
            fmt = dict(kind='series', name=y.name)
            y = y.to_numpy()
        else:
            fmt = dict(kind='array')
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y, fmt

    @staticmethod
    def _check_same_length(x1, x2):
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @staticmethod
    def _process_error_variance_settings(
        x_train, y_train, sigest, sigdf, sigquant, lamda
    ) -> tuple[Float32[Array, ''] | None, ...]:
        if y_train.dtype == bool:
            if sigest is not None:
                raise ValueError('Let `sigest=None` for binary regression')
            if lamda is not None:
                raise ValueError('Let `lamda=None` for binary regression')
            return None, None
        elif lamda is not None:
            if sigest is not None:
                raise ValueError('Let `sigest=None` if `lamda` is specified')
            return lamda, None
        else:
            if sigest is not None:
                sigest2 = jnp.square(sigest)
            elif y_train.size < 2:
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                sigest2 = jnp.var(y_train)
            else:
                x_centered = x_train.T - x_train.mean(axis=1)
                y_centered = y_train - y_train.mean()
                # centering is equivalent to adding an intercept column
                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                chisq = chisq.squeeze(0)
                dof = len(y_train) - rank
                sigest2 = chisq / dof
            # set lamda so that the `sigquant` quantile of the scaled
            # inverse chi-squared prior on the noise variance falls at sigest2
            alpha = sigdf / 2
            invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _process_type_settings(y_train, type, w):
        match type:
            case 'wbart':
                if y_train.dtype != jnp.float32:
                    raise TypeError(
                        'Continuous regression requires y_train.dtype=float32,'
                        f' got {y_train.dtype=} instead.'
                    )
            case 'pbart':
                if w is not None:
                    raise ValueError(
                        'Binary regression does not support weights, set `w=None`'
                    )
                if y_train.dtype != bool:
                    raise TypeError(
                        'Binary regression requires y_train.dtype=bool,'
                        f' got {y_train.dtype=} instead.'
                    )
            case _:
                raise ValueError(f'Invalid {type=}')

        return y_train

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
        offset: float | Float32[Any, ''] | None,
    ) -> Float32[Array, '']:
        if offset is not None:
            return jnp.asarray(offset)
        elif y_train.size < 1:
            return jnp.array(0.0)
        else:
            mean = y_train.mean()

        if y_train.dtype == bool:
            return ndtri(mean)
        else:
            return mean

    @staticmethod
    def _process_leaf_sdev_settings(
        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
        k: float,
        ntree: int,
        tau_num: FloatLike | None,
    ):
        if tau_num is None:
            if y_train.dtype == bool:
                tau_num = 3.0
            elif y_train.size < 2:
                tau_num = 1.0
            else:
                tau_num = (y_train.max() - y_train.min()) / 2

        # dividing by sqrt(ntree) makes the sum of `ntree` independent leaves
        # have prior standard deviation tau_num / k
        return tau_num / (k * math.sqrt(ntree))

    @staticmethod
    def _determine_splits(x_train, usequants, numcut):
        if usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits):
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train,
        y_train,
        offset,
        w,
        max_split,
        lamda,
        sigma_mu,
        sigdf,
        power,
        base,
        maxdepth,
        ntree,
        init_kw,
    ):
        # p_nonterminal[d] = base / (1 + d) ** power is the prior probability
        # that a node at 0-based depth d is non-terminal; nodes at the deepest
        # level allowed by `maxdepth` are always terminal
        depth = jnp.arange(maxdepth - 1)
        p_nonterminal = base / (1 + depth).astype(float) ** power

        if y_train.dtype == bool:
            sigma2_alpha = None
            sigma2_beta = None
        else:
            sigma2_alpha = sigdf / 2
            sigma2_beta = lamda * sigma2_alpha

        kw = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma_mu2=jnp.square(sigma_mu),
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_leaf=5,
        )
        if init_kw is not None:
            kw.update(init_kw)
        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
        if isinstance(seed, jax.Array) and jnp.issubdtype(
            seed.dtype, jax.dtypes.prng_key
        ):
            key = seed.copy()
            # copy because the inner loop in run_mcmc will donate the buffer
        else:
            key = jax.random.key(seed)

        kw = dict(
            n_burn=nskip,
            n_skip=keepevery,
            inner_loop_length=printevery,
            allow_overflow=True,
        )
        if printevery is not None:
            kw.update(mcmcloop.make_print_callbacks())
        if run_mcmc_kw is not None:
            kw.update(run_mcmc_kw)

        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

    @staticmethod
    def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
        if trace['sigma2'] is None:
            return None
        else:
            return jnp.sqrt(trace['sigma2'])

    @staticmethod
    def _predict(trace, x):
        return mcmcloop.evaluate_trace(trace, x)

    def _show_tree(self, i_sample, i_tree, print_all=False):
        from . import debug

        trace = self._main_trace
        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
        var_tree = trace['var_trees'][i_sample, i_tree]
        split_tree = trace['split_trees'][i_sample, i_tree]
        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)

    def _sigma_harmonic_mean(self, prior=False):
        bart = self._mcmc_state
        if prior:
            alpha = bart['sigma2_alpha']
            beta = bart['sigma2_beta']
        else:
            resid = bart['resid']
            alpha = bart['sigma2_alpha'] + resid.size / 2
            norm2 = jnp.dot(
                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
            )
            beta = bart['sigma2_beta'] + norm2 / 2
        sigma2 = beta / alpha
        return jnp.sqrt(sigma2)

    def _compare_resid(self):
        bart = self._mcmc_state
        resid1 = bart.resid

        trees = grove.evaluate_forest(
            bart.X,
            bart.forest.leaf_trees,
            bart.forest.var_trees,
            bart.forest.split_trees,
            jnp.float32,  # TODO remove these configurable dtypes around
        )

        if bart.z is not None:
            ref = bart.z
        else:
            ref = bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def _avg_acc(self):
        trace = self._main_trace

        def acc(prefix):
            acc = trace[f'{prefix}_acc_count']
            prop = trace[f'{prefix}_prop_count']
            return acc.sum() / prop.sum()

        return acc('grow'), acc('prune')

    def _avg_prop(self):
        trace = self._main_trace

        def prop(prefix):
            return trace[f'{prefix}_prop_count'].sum()

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def _avg_move(self):
        agrow, aprune = self._avg_acc()
        pgrow, pprune = self._avg_prop()
        return agrow * pgrow, aprune * pprune

    def _depth_distr(self):
        from . import debug

        trace = self._main_trace
        split_trees = trace['split_trees']
        return debug.trace_depth_distr(split_trees)

    def _points_per_leaf_distr(self):
        from . import debug

        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)

    def _check_trees(self):
        from . import debug

        return debug.check_trace(self._main_trace, self._mcmc_state)

    def _tree_goes_bad(self):
        bad = self._check_trees().astype(bool)
        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
        return bad & ~bad_before