Coverage for src/bartz/BART.py: 89%

279 statements  

coverage.py v7.9.2, created at 2025-07-07 22:47 +0000

1# bartz/src/bartz/BART.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement a class `gbart` that mimics the R BART package.""" 

26 

27import math 1ab

28from collections.abc import Sequence 1ab

29from functools import cached_property 1ab

30from typing import Any, Literal, Protocol 1ab

31 

32import jax 1ab

33import jax.numpy as jnp 1ab

34from equinox import Module, field 1ab

35from jax.scipy.special import ndtr 1ab

36from jaxtyping import ( 1ab

37 Array, 

38 Bool, 

39 Float, 

40 Float32, 

41 Int32, 

42 Integer, 

43 Key, 

44 Real, 

45 Shaped, 

46 UInt, 

47) 

48from numpy import ndarray 1ab

49 

50from bartz import mcmcloop, mcmcstep, prepcovars 1ab

51from bartz.jaxext.scipy.special import ndtri 1ab

52from bartz.jaxext.scipy.stats import invgamma 1ab

53 

54FloatLike = float | Float[Any, ''] 1ab

55 

56 

57class DataFrame(Protocol): 1ab

58 """DataFrame duck-type for `gbart`. 

59 

60 Attributes 

61 ---------- 

62 columns : Sequence[str] 

63 The names of the columns. 

64 """ 

65 

66 columns: Sequence[str] 1ab

67 

68 def to_numpy(self) -> ndarray: 1ab

69 """Convert the dataframe to a 2d numpy array with columns on the second axis.""" 

70 ... 

71 

72 

73class Series(Protocol): 1ab

74 """Series duck-type for `gbart`. 

75 

76 Attributes 

77 ---------- 

78 name : str | None 

79 The name of the series. 

80 """ 

81 

82 name: str | None 1ab

83 

84 def to_numpy(self) -> ndarray: 1ab

85 """Convert the series to a 1d numpy array.""" 

86 ... 

87 

88 

89class gbart(Module): 1ab

90 R""" 

91 Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_. 

92 

93 Regress `y_train` on `x_train` with a latent mean function represented as 

94 a sum of decision trees. The inference is carried out by sampling the 

95 posterior distribution of the tree ensemble with an MCMC. 

96 

97 Parameters 

98 ---------- 

99 x_train 

100 The training predictors. 

101 y_train 

102 The training responses. 

103 x_test 

104 The test predictors. 

105 type 

106 The type of regression. 'wbart' for continuous regression, 'pbart' for 

107 binary regression with probit link. 

108 sparse 

109 Whether to activate variable selection on the predictors as done in 

110 [1]_. 

111 theta 

112 a 

113 b 

114 rho 

115 Hyperparameters of the sparsity prior used for variable selection. 

116 

117 The prior distribution on the choice of predictor for each decision rule 

118 is 

119 

120 .. math:: 

121 (s_1, \ldots, s_p) \sim 

122 \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p). 

123 

124 If `theta` is not specified, it's a priori distributed according to 

125 

126 .. math:: 

127 \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim 

128 \operatorname{Beta}(\mathtt{a}, \mathtt{b}). 

129 

130 If not specified, `rho` is set to the number of predictors p. To tune 

131 the prior, consider setting a lower `rho` to prefer more sparsity. 

132 If setting `theta` directly, it should be in the ballpark of p or lower 

133 as well. 

134 xinfo 

135 A matrix with the cutpoints to use to bin each predictor. If not 

136 specified, it is generated automatically according to `usequants` and 

137 `numcut`. 

138 

139 Each row shall contain a sorted list of cutpoints for a predictor. If 

140 there are fewer cutpoints than the number of columns in the matrix, 

141 fill the remaining cells with NaN. 

142 

143 `xinfo` shall be a matrix even if `x_train` is a dataframe. 

144 usequants 

145 Whether to use predictor quantiles instead of a uniform grid to bin 

146 predictors. Ignored if `xinfo` is specified. 

147 rm_const 

148 How to treat predictors with no associated decision rules (i.e., there 

149 are no available cutpoints for that predictor). If `True` (default), 

150 they are ignored. If `False`, an error is raised if there are any. If 

151 `None`, no check is performed, and the output of the MCMC may not make 

152 sense if there are predictors without cutpoints. The option `None` is 

153 provided only to allow jax tracing. 

154 sigest 

155 An estimate of the residual standard deviation on `y_train`, used to set 

156 `lamda`. If not specified, it is estimated by linear regression (with 

157 intercept, and without taking into account `w`). If `y_train` has fewer 

158 than two elements, it is set to 1. If n <= p, it is set to the standard 

159 deviation of `y_train`. Ignored if `lamda` is specified. 

160 sigdf 

161 The degrees of freedom of the scaled inverse-chisquared prior on the 

162 noise variance. 

163 sigquant 

164 The quantile of the prior on the noise variance that shall match 

165 `sigest` to set the scale of the prior. Ignored if `lamda` is specified. 

166 k 

167 The inverse scale of the prior standard deviation on the latent mean 

168 function, relative to half the observed range of `y_train`. If `y_train` 

169 has fewer than two elements, `k` is ignored and the scale is set to 1. 

170 power 

171 base 

172 Parameters of the prior on tree node generation. The probability that a 

173 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) ** 

174 power``. 

175 lamda 

176 The prior harmonic mean of the error variance. (The harmonic mean of x 

177 is 1/mean(1/x).) If not specified, it is set based on `sigest` and 

178 `sigquant`. 

179 tau_num 

180 The numerator in the expression that determines the prior standard 

181 deviation of leaves. If not specified, it defaults to ``(max(y_train) - 

182 min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for 

183 continuous regression, and 3 for binary regression. 

184 offset 

185 The prior mean of the latent mean function. If not specified, it is set 

186 to the mean of `y_train` for continuous regression, and to 

187 ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty, 

188 `offset` is set to 0. With binary regression, if `y_train` is all 

189 `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or 

190 ``Phi^-1(n/(n+1))``, respectively. 

191 w 

192 Coefficients that rescale the error standard deviation on each 

193 datapoint. Not specifying `w` is equivalent to setting it to 1 for all 

194 datapoints. Note: `w` is ignored in the automatic determination of 

195 `sigest`, so either the weights should be O(1), or `sigest` should be 

196 specified by the user. 

197 ntree 

198 The number of trees used to represent the latent mean function. By 

199 default 200 for continuous regression and 50 for binary regression. 

200 numcut 

201 If `usequants` is `False`: the exact number of cutpoints used to bin the 

202 predictors, ranging between the minimum and maximum observed values 

203 (excluded). 

204 

205 If `usequants` is `True`: the maximum number of cutpoints to use for 

206 binning the predictors. Each predictor is binned such that its 

207 distribution in `x_train` is approximately uniform across bins. The 

208 number of bins is at most the number of unique values appearing in 

209 `x_train`, or ``numcut + 1``. 

210 

211 Before running the algorithm, the predictors are compressed to the 

212 smallest integer type that fits the bin indices, so `numcut` is best set 

213 to the maximum value of an unsigned integer type, like 255. 

214 

215 Ignored if `xinfo` is specified. 

216 ndpost 

217 The number of MCMC samples to save, after burn-in. 

218 nskip 

219 The number of initial MCMC samples to discard as burn-in. 

220 keepevery 

221 The thinning factor for the MCMC samples, after burn-in. By default, 1 

222 for continuous regression and 10 for binary regression. 

223 printevery 

224 The number of iterations (including thinned-away ones) between each log 

225 line. Set to `None` to disable logging. 

226 

227 `printevery` has a few unexpected side effects. On CPU, interrupting 

228 with ^C halts the MCMC only at the next log line. Also, the total number of 

229 iterations is a multiple of `printevery`, so if ``nskip + keepevery * 

230 ndpost`` is not a multiple of `printevery`, some of the last iterations 

231 will not be saved. 

232 seed 

233 The seed for the random number generator. 

234 maxdepth 

235 The maximum depth of the trees. This is 1-based, so with the default 

236 ``maxdepth=6``, the depths of the levels range from 0 to 5. 

237 init_kw 

238 Additional arguments passed to `bartz.mcmcstep.init`. 

239 run_mcmc_kw 

240 Additional arguments passed to `bartz.mcmcloop.run_mcmc`. 

241 

242 Attributes 

243 ---------- 

244 offset : Float32[Array, ''] 

245 The prior mean of the latent mean function. 

246 sigest : Float32[Array, ''] | None 

247 The estimated standard deviation of the error used to set `lamda`. 

248 yhat_test : Float32[Array, 'ndpost m'] | None 

249 The conditional posterior mean at `x_test` for each MCMC iteration. 

250 

251 Notes 

252 ----- 

253 This interface imitates the function ``gbart`` from the R package `BART 

254 <https://cran.r-project.org/package=BART>`_, but with these differences: 

255 

256 - If `x_train` and `x_test` are matrices, they have one predictor per row 

257 instead of per column. 

258 - If ``usequants=False``, R BART switches to quantiles anyway if there are 

259 fewer predictor values than the required number of bins, while bartz 

260 always follows the specification. 

261 - Some functionality is missing. 

262 - The error variance parameter is called `lamda` instead of `lambda`. 

263 - There are some additional attributes, and some are missing. 

264 - The trees have a maximum depth. 

265 - `rm_const` refers to predictors without decision rules instead of 

266 predictors that are constant in `x_train`. 

267 - If `rm_const=True` and some variables are dropped, the predictors 

268 matrix/dataframe passed to `predict` should still include them. 

269 

270 References 

271 ---------- 

272 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for 

273 High-Dimensional Prediction and Variable Selection”. In: Journal of the 

274 American Statistical Association 113.522, pp. 626-636. 

275 .. [2] Chipman, Hugh A., Edward I. George and Robert E. McCulloch (2010). 

276 “BART: Bayesian Additive Regression Trees”. In: The Annals of Applied 

277 Statistics 4.1, pp. 266-298. 
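Examples
--------
A minimal usage sketch on synthetic data (the settings below are chosen only
to keep the example fast; they are not recommended values):

>>> import jax.numpy as jnp
>>> from bartz.BART import gbart
>>> x_train = jnp.linspace(0, 1, 100)[None, :]  # one predictor per row
>>> y_train = jnp.sin(8 * x_train[0])
>>> fit = gbart(x_train, y_train, x_test=x_train,
...             ntree=20, ndpost=100, nskip=50, printevery=None)
>>> fit.yhat_test_mean.shape
(100,)
>>> fit.sigma_mean.shape
()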

278 """ 

279 

280 _main_trace: mcmcloop.MainTrace 1ab

281 _burnin_trace: mcmcloop.BurninTrace 1ab

282 _mcmc_state: mcmcstep.State 1ab

283 _splits: Real[Array, 'p max_num_splits'] 1ab

284 _x_train_fmt: Any = field(static=True) 1ab

285 

286 ndpost: int = field(static=True) 1ab

287 offset: Float32[Array, ''] 1ab

288 sigest: Float32[Array, ''] | None = None 1ab

289 yhat_test: Float32[Array, 'ndpost m'] | None = None 1ab

290 

291 def __init__( 1ab

292 self, 

293 x_train: Real[Array, 'p n'] | DataFrame, 

294 y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series, 

295 *, 

296 x_test: Real[Array, 'p m'] | DataFrame | None = None, 

297 type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002 

298 sparse: bool = False, 

299 theta: FloatLike | None = None, 

300 a: FloatLike = 0.5, 

301 b: FloatLike = 1.0, 

302 rho: FloatLike | None = None, 

303 xinfo: Float[Array, 'p n'] | None = None, 

304 usequants: bool = False, 

305 rm_const: bool | None = True, 

306 sigest: FloatLike | None = None, 

307 sigdf: FloatLike = 3.0, 

308 sigquant: FloatLike = 0.9, 

309 k: FloatLike = 2.0, 

310 power: FloatLike = 2.0, 

311 base: FloatLike = 0.95, 

312 lamda: FloatLike | None = None, 

313 tau_num: FloatLike | None = None, 

314 offset: FloatLike | None = None, 

315 w: Float[Array, ' n'] | None = None, 

316 ntree: int | None = None, 

317 numcut: int = 100, 

318 ndpost: int = 1000, 

319 nskip: int = 100, 

320 keepevery: int | None = None, 

321 printevery: int | None = 100, 

322 seed: int | Key[Array, ''] = 0, 

323 maxdepth: int = 6, 

324 init_kw: dict | None = None, 

325 run_mcmc_kw: dict | None = None, 

326 ): 

327 # check data and put it in the right format 

328 x_train, x_train_fmt = self._process_predictor_input(x_train) 1ab

329 y_train = self._process_response_input(y_train) 1ab

330 self._check_same_length(x_train, y_train) 1ab

331 if w is not None: 1ab

332 w = self._process_response_input(w) 1ab

333 self._check_same_length(x_train, w) 1ab

334 

335 # check data types are correct for continuous/binary regression 

336 self._check_type_settings(y_train, type, w) 1ab

337 # from here onwards, the type is determined by y_train.dtype == bool 

338 

339 # set defaults that depend on type of regression 

340 if ntree is None: 1ab

341 ntree = 50 if y_train.dtype == bool else 200 1ab

342 if keepevery is None: 1ab

343 keepevery = 10 if y_train.dtype == bool else 1 1ab

344 

345 # process sparsity settings 

346 theta, a, b, rho = self._process_sparsity_settings( 1ab

347 x_train, sparse, theta, a, b, rho 

348 ) 

349 

350 # process "standardization" settings 

351 offset = self._process_offset_settings(y_train, offset) 1ab

352 sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num) 1ab

353 lamda, sigest = self._process_error_variance_settings( 1ab

354 x_train, y_train, sigest, sigdf, sigquant, lamda 

355 ) 

356 

357 # determine splits 

358 splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo) 1ab

359 x_train = self._bin_predictors(x_train, splits) 1ab

360 

361 # setup and run mcmc 

362 initial_state = self._setup_mcmc( 1ab

363 x_train, 

364 y_train, 

365 offset, 

366 w, 

367 max_split, 

368 lamda, 

369 sigma_mu, 

370 sigdf, 

371 power, 

372 base, 

373 maxdepth, 

374 ntree, 

375 init_kw, 

376 rm_const, 

377 theta, 

378 a, 

379 b, 

380 rho, 

381 ) 

382 final_state, burnin_trace, main_trace = self._run_mcmc( 1ab

383 initial_state, 

384 ndpost, 

385 nskip, 

386 keepevery, 

387 printevery, 

388 seed, 

389 run_mcmc_kw, 

390 sparse, 

391 ) 

392 

393 # set public attributes 

394 self.offset = final_state.offset # from the state because of buffer donation 1ab

395 self.ndpost = ndpost 1ab

396 self.sigest = sigest 1ab

397 

398 # set private attributes 

399 self._main_trace = main_trace 1ab

400 self._burnin_trace = burnin_trace 1ab

401 self._mcmc_state = final_state 1ab

402 self._splits = splits 1ab

403 self._x_train_fmt = x_train_fmt 1ab

404 

405 # predict at test points 

406 if x_test is not None: 1ab

407 self.yhat_test = self.predict(x_test) 1ab

408 

409 @cached_property 1ab

410 def prob_test(self) -> Float32[Array, 'ndpost m'] | None: 1ab

411 """The posterior probability of y being True at `x_test` for each MCMC iteration.""" 

412 if self.yhat_test is None or self._mcmc_state.y.dtype != bool: 1ab

413 return None 1ab

414 else: 

415 return ndtr(self.yhat_test) 1ab

416 

417 @cached_property 1ab

418 def prob_test_mean(self) -> Float32[Array, ' m'] | None: 1ab

419 """The marginal posterior probability of y being True at `x_test`.""" 

420 if self.prob_test is None: 1ab

421 return None 1ab

422 else: 

423 return self.prob_test.mean(axis=0) 1ab

424 

425 @cached_property 1ab

426 def prob_train(self) -> Float32[Array, 'ndpost n'] | None: 1ab

427 """The posterior probability of y being True at `x_train` for each MCMC iteration.""" 

428 if self._mcmc_state.y.dtype == bool: 1ab

429 return ndtr(self.yhat_train) 1ab

430 else: 

431 return None 1ab

432 

433 @cached_property 1ab

434 def prob_train_mean(self) -> Float32[Array, ' n'] | None: 1ab

435 """The marginal posterior probability of y being True at `x_train`.""" 

436 if self.prob_train is None: 1ab

437 return None 1ab

438 else: 

439 return self.prob_train.mean(axis=0) 1ab

440 

441 @cached_property 1ab

442 def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None: 1ab

443 """The standard deviation of the error, including burn-in samples.""" 

444 if self._burnin_trace.sigma2 is None: 1ab

445 return None 1ab

446 else: 

447 assert self._main_trace.sigma2 is not None 1ab

448 return jnp.sqrt( 1ab

449 jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2]) 

450 ) 

451 

452 @cached_property 1ab

453 def sigma_mean(self) -> Float32[Array, ''] | None: 1ab

454 """The mean of `sigma`, only over the post-burnin samples.""" 

455 if self.sigma is None: 1ab

456 return None 1ab

457 else: 

458 return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0) 1ab

459 

460 @cached_property 1ab

461 def varcount(self) -> Int32[Array, 'ndpost p']: 1ab

462 """Histogram of predictor usage for decision rules in the trees.""" 

463 return mcmcloop.compute_varcount( 1ab

464 self._mcmc_state.forest.max_split.size, self._main_trace 

465 ) 

466 

467 @cached_property 1ab

468 def varcount_mean(self) -> Float32[Array, ' p']: 1ab

469 """Average of `varcount` across MCMC iterations.""" 

470 return self.varcount.mean(axis=0) 1ab

471 

472 @cached_property 1ab

473 def varprob(self) -> Float32[Array, 'ndpost p']: 1ab

474 """Posterior samples of the probability of choosing each predictor for a decision rule.""" 

475 varprob = self._main_trace.varprob 1ab

476 if varprob is None: 1ab

477 max_split = self._mcmc_state.forest.max_split 1ab

478 p = max_split.size 1ab

479 peff = jnp.count_nonzero(max_split) 1ab
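# With variable selection off there is no varprob trace, so the probability
# falls back to a uniform distribution over the peff predictors that have at
# least one available cutpoint; predictors without cutpoints get probability 0.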

480 varprob = jnp.where(max_split, 1 / peff, 0) 1ab

481 varprob = jnp.broadcast_to(varprob, (self.ndpost, p)) 1ab

482 return varprob 1ab

483 

484 @cached_property 1ab

485 def varprob_mean(self) -> Float32[Array, ' p']: 1ab

486 """The marginal posterior probability of each predictor being chosen for a decision rule.""" 

487 return self.varprob.mean(axis=0) 1ab

488 

489 @cached_property 1ab

490 def yhat_test_mean(self) -> Float32[Array, ' m'] | None: 1ab

491 """The marginal posterior mean at `x_test`. 

492 

493 Not defined for binary regression because it is error-prone; typically 

494 the right quantity to consider is `prob_test_mean`. 

495 """ 

496 if self.yhat_test is None or self._mcmc_state.y.dtype == bool: 1ab

497 return None 1ab

498 else: 

499 return self.yhat_test.mean(axis=0) 1ab

500 

501 @cached_property 1ab

502 def yhat_train(self) -> Float32[Array, 'ndpost n']: 1ab

503 """The conditional posterior mean at `x_train` for each MCMC iteration.""" 

504 x_train = self._mcmc_state.X 1ab

505 return self._predict(x_train) 1ab

506 

507 @cached_property 1ab

508 def yhat_train_mean(self) -> Float32[Array, ' n'] | None: 1ab

509 """The marginal posterior mean at `x_train`. 

510 

511 Not defined for binary regression because it is error-prone; typically 

512 the right quantity to consider is `prob_train_mean`. 

513 """ 

514 if self._mcmc_state.y.dtype == bool: 1ab

515 return None 1ab

516 else: 

517 return self.yhat_train.mean(axis=0) 1ab

518 

519 def predict( 1ab

520 self, x_test: Real[Array, 'p m'] | DataFrame 

521 ) -> Float32[Array, 'ndpost m']: 

522 """ 

523 Compute the posterior mean at `x_test` for each MCMC iteration. 

524 

525 Parameters 

526 ---------- 

527 x_test 

528 The test predictors. 

529 

530 Returns 

531 ------- 

532 The conditional posterior mean at `x_test` for each MCMC iteration. 

533 

534 Raises 

535 ------ 

536 ValueError 

537 If `x_test` has a different format than `x_train`. 
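Examples
--------
A sketch, assuming `bart` is a fitted `gbart` and `x_new` has the same
format (matrix or dataframe) as the training predictors:

>>> yhat = bart.predict(x_new)  # shape (ndpost, m)
>>> yhat_mean = yhat.mean(axis=0)  # posterior mean at each test point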

538 """ 

539 x_test, x_test_fmt = self._process_predictor_input(x_test) 1ab

540 if x_test_fmt != self._x_train_fmt: 1ab

541 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 1ab

542 raise ValueError(msg) 1ab

543 x_test = self._bin_predictors(x_test, self._splits) 1ab

544 return self._predict(x_test) 1ab

545 

546 @staticmethod 1ab

547 def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: 1ab
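# Accept either a dataframe-like object (one predictor per column) or an array
# (one predictor per row); the returned fmt record is stored so that `predict`
# can later check that x_test has the same format as x_train.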

548 if hasattr(x, 'columns'): 1ab

549 fmt = dict(kind='dataframe', columns=x.columns) 1ab

550 x = x.to_numpy().T 1ab

551 else: 

552 fmt = dict(kind='array', num_covar=x.shape[0]) 1ab

553 x = jnp.asarray(x) 1ab

554 assert x.ndim == 2 1ab

555 return x, fmt 1ab

556 

557 @staticmethod 1ab

558 def _process_response_input(y) -> Shaped[Array, ' n']: 1ab

559 if hasattr(y, 'to_numpy'): 1ab

560 y = y.to_numpy() 1ab

561 y = jnp.asarray(y) 1ab

562 assert y.ndim == 1 1ab

563 return y 1ab

564 

565 @staticmethod 1ab

566 def _check_same_length(x1, x2): 1ab

567 get_length = lambda x: x.shape[-1] 1ab

568 assert get_length(x1) == get_length(x2) 1ab

569 

570 @staticmethod 1ab

571 def _process_error_variance_settings( 1ab

572 x_train, y_train, sigest, sigdf, sigquant, lamda 

573 ) -> tuple[Float32[Array, ''] | None, ...]: 

574 if y_train.dtype == bool: 1ab

575 if sigest is not None: 1ab (575 ↛ 576: line 575 didn't jump to line 576 because the condition on line 575 was never true)

576 msg = 'Let `sigest=None` for binary regression' 

577 raise ValueError(msg) 

578 if lamda is not None: 1ab (578 ↛ 579: line 578 didn't jump to line 579 because the condition on line 578 was never true)

579 msg = 'Let `lamda=None` for binary regression' 

580 raise ValueError(msg) 

581 return None, None 1ab

582 elif lamda is not None: 1ab (582 ↛ 583: line 582 didn't jump to line 583 because the condition on line 582 was never true)

583 if sigest is not None: 

584 msg = 'Let `sigest=None` if `lamda` is specified' 

585 raise ValueError(msg) 

586 return lamda, None 

587 else: 

588 if sigest is not None: 1ab (588 ↛ 589: line 588 didn't jump to line 589 because the condition on line 588 was never true)

589 sigest2 = jnp.square(sigest) 

590 elif y_train.size < 2: 1ab

591 sigest2 = 1 1ab

592 elif y_train.size <= x_train.shape[0]: 1ab

593 sigest2 = jnp.var(y_train) 1ab

594 else: 

595 x_centered = x_train.T - x_train.mean(axis=1) 1ab

596 y_centered = y_train - y_train.mean() 1ab

597 # centering is equivalent to adding an intercept column 

598 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered) 1ab

599 chisq = chisq.squeeze(0) 1ab

600 dof = len(y_train) - rank 1ab

601 sigest2 = chisq / dof 1ab

602 alpha = sigdf / 2 1ab

603 invchi2 = invgamma.ppf(sigquant, alpha) / 2 1ab

604 invchi2rid = invchi2 * sigdf 1ab
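# The returned lamda makes the prior sigma2 ~ InvGamma(sigdf/2, lamda*sigdf/2)
# (see _setup_mcmc) put probability sigquant below sigest2, i.e.
# P(sigma2 <= sigest2) = sigquant, which is how `sigquant` is documented above.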

605 return sigest2 / invchi2rid, jnp.sqrt(sigest2) 1ab

606 

607 @staticmethod 1ab

608 def _check_type_settings(y_train, type, w): # noqa: A002 1ab

609 match type: 1ab

610 case 'wbart': 1ab

611 if y_train.dtype != jnp.float32: 1ab (611 ↛ 612: line 611 didn't jump to line 612 because the condition on line 611 was never true)

612 msg = ( 

613 'Continuous regression requires y_train.dtype=float32,' 

614 f' got {y_train.dtype=} instead.' 

615 ) 

616 raise TypeError(msg) 1a

617 case 'pbart': 1ab (617 ↛ 627: line 617 didn't jump to line 627 because the pattern on line 617 always matched)

618 if w is not None: 1ab (618 ↛ 619: line 618 didn't jump to line 619 because the condition on line 618 was never true)

619 msg = 'Binary regression does not support weights, set `w=None`' 

620 raise ValueError(msg) 

621 if y_train.dtype != bool: 1ab (621 ↛ 622: line 621 didn't jump to line 622 because the condition on line 621 was never true)

622 msg = ( 

623 'Binary regression requires y_train.dtype=bool,' 

624 f' got {y_train.dtype=} instead.' 

625 ) 

626 raise TypeError(msg) 1a

627 case _: 

628 msg = f'Invalid {type=}' 

629 raise ValueError(msg) 

630 

631 @staticmethod 1ab

632 def _process_sparsity_settings( 1ab

633 x_train: Real[Array, 'p n'], 

634 sparse: bool, 

635 theta: FloatLike | None, 

636 a: FloatLike, 

637 b: FloatLike, 

638 rho: FloatLike | None, 

639 ) -> ( 

640 tuple[None, None, None, None] 

641 | tuple[FloatLike, None, None, None] 

642 | tuple[None, FloatLike, FloatLike, FloatLike] 

643 ): 
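# Three mutually exclusive configurations are returned:
#   (None, None, None, None)   variable selection off
#   (theta, None, None, None)  fixed Dirichlet concentration theta
#   (None, a, b, rho)          theta given the Beta(a, b) hyperprior, with rho
#                              defaulting to the number of predictors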

644 if not sparse: 1ab

645 return None, None, None, None 1ab

646 elif theta is not None: 1ab

647 return theta, None, None, None 1ab

648 else: 

649 if rho is None: 1ab (649 ↛ 652: line 649 didn't jump to line 652 because the condition on line 649 was always true)

650 p, _ = x_train.shape 1ab

651 rho = float(p) 1ab

652 return None, a, b, rho 1ab

653 

654 @staticmethod 1ab

655 def _process_offset_settings( 1ab

656 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

657 offset: float | Float32[Any, ''] | None, 

658 ) -> Float32[Array, '']: 

659 if offset is not None: 1ab (659 ↛ 660: line 659 didn't jump to line 660 because the condition on line 659 was never true)

660 return jnp.asarray(offset) 

661 elif y_train.size < 1: 1ab

662 return jnp.array(0.0) 1ab

663 else: 

664 mean = y_train.mean() 1ab

665 

666 if y_train.dtype == bool: 1ab

667 bound = 1 / (1 + y_train.size) 1ab

668 mean = jnp.clip(mean, bound, 1 - bound) 1ab

669 return ndtri(mean) 1ab

670 else: 

671 return mean 1ab

672 

673 @staticmethod 1ab

674 def _process_leaf_sdev_settings( 1ab

675 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

676 k: float, 

677 ntree: int, 

678 tau_num: FloatLike | None, 

679 ): 

680 if tau_num is None: 1ab (680 ↛ 688: line 680 didn't jump to line 688 because the condition on line 680 was always true)

681 if y_train.dtype == bool: 1ab

682 tau_num = 3.0 1ab

683 elif y_train.size < 2: 1ab

684 tau_num = 1.0 1ab

685 else: 

686 tau_num = (y_train.max() - y_train.min()) / 2 1ab

687 
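# Per-leaf prior standard deviation: the sum of ntree i.i.d. leaves then has
# prior sd tau_num / k, so k prior standard deviations of the latent function
# cover tau_num (half the observed range of y_train by default).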

688 return tau_num / (k * math.sqrt(ntree)) 1ab

689 

690 @staticmethod 1ab

691 def _determine_splits( 1ab

692 x_train: Real[Array, 'p n'], 

693 usequants: bool, 

694 numcut: int, 

695 xinfo: Float[Array, 'p n'] | None, 

696 ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]: 

697 if xinfo is not None: 1ab

698 if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]: 1ab

699 msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)' 1ab

700 raise ValueError(msg) 1ab

701 return prepcovars.parse_xinfo(xinfo) 1ab

702 elif usequants: 1ab

703 return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1) 1ab

704 else: 

705 return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1) 1ab

706 

707 @staticmethod 1ab

708 def _bin_predictors(x, splits) -> UInt[Array, 'p n']: 1ab

709 return prepcovars.bin_predictors(x, splits) 1ab

710 

711 @staticmethod 1ab

712 def _setup_mcmc( 1ab

713 x_train: Real[Array, 'p n'], 

714 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

715 offset: Float32[Array, ''], 

716 w: Float[Array, ' n'] | None, 

717 max_split: UInt[Array, ' p'], 

718 lamda: Float32[Array, ''] | None, 

719 sigma_mu: FloatLike, 

720 sigdf: FloatLike, 

721 power: FloatLike, 

722 base: FloatLike, 

723 maxdepth: int, 

724 ntree: int, 

725 init_kw: dict[str, Any] | None, 

726 rm_const: bool | None, 

727 theta: FloatLike | None, 

728 a: FloatLike | None, 

729 b: FloatLike | None, 

730 rho: FloatLike | None, 

731 ): 

732 depth = jnp.arange(maxdepth - 1) 1ab

733 p_nonterminal = base / (1 + depth).astype(float) ** power 1ab
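# Illustration (assuming the defaults base=0.95, power=2.0, maxdepth=6):
# p_nonterminal = [0.95, 0.2375, 0.1056, 0.0594, 0.038] for depths 0 to 4;
# nodes at depth 5 are always terminal because of the depth limit.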

734 

735 if y_train.dtype == bool: 1ab

736 sigma2_alpha = None 1ab

737 sigma2_beta = None 1ab

738 else: 

739 sigma2_alpha = sigdf / 2 1ab

740 sigma2_beta = lamda * sigma2_alpha 1ab

741 

742 kw = dict( 1ab

743 X=x_train, 

744 # copy y_train because it's going to be donated in the mcmc loop 

745 y=jnp.array(y_train), 

746 offset=offset, 

747 error_scale=w, 

748 max_split=max_split, 

749 num_trees=ntree, 

750 p_nonterminal=p_nonterminal, 

751 sigma_mu2=jnp.square(sigma_mu), 

752 sigma2_alpha=sigma2_alpha, 

753 sigma2_beta=sigma2_beta, 

754 min_points_per_decision_node=10, 

755 min_points_per_leaf=5, 

756 theta=theta, 

757 a=a, 

758 b=b, 

759 rho=rho, 

760 ) 

761 

762 if rm_const is None: 1ab

763 kw.update(filter_splitless_vars=False) 1ab

764 elif rm_const: 1ab (764 ↛ 767: line 764 didn't jump to line 767 because the condition on line 764 was always true)

765 kw.update(filter_splitless_vars=True) 1ab

766 else: 

767 n_empty = jnp.count_nonzero(max_split == 0) 

768 if n_empty: 

769 msg = f'There are {n_empty}/{max_split.size} predictors without decision rules' 

770 raise ValueError(msg) 

771 kw.update(filter_splitless_vars=False) 

772 

773 if init_kw is not None: 1ab

774 kw.update(init_kw) 1ab

775 

776 return mcmcstep.init(**kw) 1ab

777 

778 @staticmethod 1ab

779 def _run_mcmc( 1ab

780 mcmc_state: mcmcstep.State, 

781 ndpost: int, 

782 nskip: int, 

783 keepevery: int, 

784 printevery: int | None, 

785 seed: int | Integer[Array, ''] | Key[Array, ''], 

786 run_mcmc_kw: dict | None, 

787 sparse: bool, 

788 ): 

789 # prepare random generator seed 

790 if isinstance(seed, jax.Array) and jnp.issubdtype( 1ab

791 seed.dtype, jax.dtypes.prng_key 

792 ): 

793 key = seed.copy() 1ab

794 # copy because the inner loop in run_mcmc will donate the buffer 

795 else: 

796 key = jax.random.key(seed) 1ab

797 

798 # prepare arguments 

799 kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery) 1ab

800 kw.update( 1ab

801 mcmcloop.make_default_callback( 

802 dot_every=None if printevery is None or printevery == 1 else 1, 

803 report_every=printevery, 

804 sparse_on_at=nskip // 2 if sparse else None, 

805 ) 

806 ) 

807 if run_mcmc_kw is not None: 1ab

808 kw.update(run_mcmc_kw) 1ab

809 

810 return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw) 1ab

811 

812 def _predict(self, x): 1ab

813 return mcmcloop.evaluate_trace(self._main_trace, x) 1ab