Coverage for src/bartz/BART.py: 91%

336 statements  

coverage.py v7.10.1, created at 2025-07-31 16:09 +0000

1# bartz/src/bartz/BART.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement classes `mc_gbart` and `gbart` that mimic the R BART package.""" 

26 

27import math 1ab

28from collections.abc import Sequence 1ab

29from functools import cached_property, partial 1ab

30from typing import Any, Literal, Protocol 1ab

31 

32import jax 1ab

33import jax.numpy as jnp 1ab

34from equinox import Module, field 1ab

35from jax.scipy.special import ndtr 1ab

36from jax.tree import map_with_path 1ab

37from jaxtyping import ( 1ab

38 Array, 

39 Bool, 

40 Float, 

41 Float32, 

42 Int32, 

43 Integer, 

44 Key, 

45 Real, 

46 Shaped, 

47 UInt, 

48) 

49from numpy import ndarray 1ab

50 

51from bartz import mcmcloop, mcmcstep, prepcovars 1ab

52from bartz.jaxext.scipy.special import ndtri 1ab

53from bartz.jaxext.scipy.stats import invgamma 1ab

54 

55FloatLike = float | Float[Any, ''] 1ab

56 

57 

58class DataFrame(Protocol): 1ab

59 """DataFrame duck-type for `mc_gbart`. 

60 

61 Attributes 

62 ---------- 

63 columns : Sequence[str] 

64 The names of the columns. 

65 """ 

66 

67 columns: Sequence[str] 1ab

68 

69 def to_numpy(self) -> ndarray: 1ab

70 """Convert the dataframe to a 2d numpy array with columns on the second axis.""" 

71 ... 

72 

73 

74class Series(Protocol): 1ab

75 """Series duck-type for `mc_gbart`. 

76 

77 Attributes 

78 ---------- 

79 name : str | None 

80 The name of the series. 

81 """ 

82 

83 name: str | None 1ab

84 

85 def to_numpy(self) -> ndarray: 1ab

86 """Convert the series to a 1d numpy array.""" 

87 ... 

88 

89 
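# Illustration (not part of the original module): a pandas DataFrame/Series
# already satisfies these duck types, since both expose `columns`/`name` and
# `to_numpy()`. A minimal sketch, assuming pandas is available:
#
#     import pandas as pd
#     df = pd.DataFrame({'x1': [0.1, 0.2, 0.3], 'x2': [1.0, 2.0, 3.0]})
#     list(df.columns)    # ['x1', 'x2']
#     df.to_numpy().T     # shape (p, n) = (2, 3), the layout mc_gbart expects
#     y = pd.Series([0.5, 1.5, 2.5], name='y')
#     y.to_numpy()        # 1d array of length n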

90class mc_gbart(Module): 1ab

91 R""" 

92 Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_. 

93 

94 Regress `y_train` on `x_train` with a latent mean function represented as 

95 a sum of decision trees. The inference is carried out by sampling the 

96 posterior distribution of the tree ensemble with an MCMC. 

97 

98 Parameters 

99 ---------- 

100 x_train 

101 The training predictors. 

102 y_train 

103 The training responses. 

104 x_test 

105 The test predictors. 

106 type 

107 The type of regression. 'wbart' for continuous regression, 'pbart' for 

108 binary regression with probit link. 

109 sparse 

110 Whether to activate variable selection on the predictors as done in 

111 [1]_. 

112 theta 

113 a 

114 b 

115 rho 

116 Hyperparameters of the sparsity prior used for variable selection. 

117 

118 The prior distribution on the choice of predictor for each decision rule 

119 is 

120 

121 .. math:: 

122 (s_1, \ldots, s_p) \sim 

123 \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p). 

124 

125 If `theta` is not specified, it's a priori distributed according to 

126 

127 .. math:: 

128 \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim 

129 \operatorname{Beta}(\mathtt{a}, \mathtt{b}). 

130 

131 If not specified, `rho` is set to the number of predictors p. To tune 

132 the prior, consider setting a lower `rho` to prefer more sparsity. 

133 If setting `theta` directly, it should be in the ballpark of p or lower 

134 as well. 

135 xinfo 

136 A matrix with the cutpoints to use to bin each predictor. If not 

137 specified, it is generated automatically according to `usequants` and 

138 `numcut`. 

139 

140 Each row shall contain a sorted list of cutpoints for a predictor. If 

141 there are fewer cutpoints than the number of columns in the matrix, 

142 fill the remaining cells with NaN. 

143 

144 `xinfo` shall be a matrix even if `x_train` is a dataframe. 

145 usequants 

146 Whether to use predictor quantiles instead of a uniform grid to bin 

147 predictors. Ignored if `xinfo` is specified. 

148 rm_const 

149 How to treat predictors with no associated decision rules (i.e., there 

150 are no available cutpoints for that predictor). If `True` (default), 

151 they are ignored. If `False`, an error is raised if there are any. If 

152 `None`, no check is performed, and the output of the MCMC may not make 

153 sense if there are predictors without cutpoints. The option `None` is 

154 provided only to allow jax tracing. 

155 sigest 

156 An estimate of the residual standard deviation on `y_train`, used to set 

157 `lamda`. If not specified, it is estimated by linear regression (with 

158 intercept, and without taking into account `w`). If `y_train` has fewer 

159 than two elements, it is set to 1. If n <= p, it is set to the standard 

160 deviation of `y_train`. Ignored if `lamda` is specified. 

161 sigdf 

162 The degrees of freedom of the scaled inverse-chisquared prior on the 

163 noise variance. 

164 sigquant 

165 The quantile of the prior on the noise variance that shall match 

166 `sigest` to set the scale of the prior. Ignored if `lamda` is specified. 

167 k 

168 The inverse scale of the prior standard deviation on the latent mean 

169 function, relative to half the observed range of `y_train`. If `y_train` 

170 has fewer than two elements, `k` is ignored and the scale is set to 1. 

171 power 

172 base 

173 Parameters of the prior on tree node generation. The probability that a 

174 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) ** 

175 power``. 

176 lamda 

177 The prior harmonic mean of the error variance. (The harmonic mean of x 

178 is 1/mean(1/x).) If not specified, it is set based on `sigest` and 

179 `sigquant`. 

180 tau_num 

181 The numerator in the expression that determines the prior standard 

182 deviation of leaves. If not specified, defaults to ``(max(y_train) - 

183 min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for 

184 continuous regression, and 3 for binary regression. 

185 offset 

186 The prior mean of the latent mean function. If not specified, it is set 

187 to the mean of `y_train` for continuous regression, and to 

188 ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty, 

189 `offset` is set to 0. With binary regression, if `y_train` is all 

190 `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or 

191 ``Phi^-1(n/(n+1))``, respectively. 

192 w 

193 Coefficients that rescale the error standard deviation on each 

194 datapoint. Not specifying `w` is equivalent to setting it to 1 for all 

195 datapoints. Note: `w` is ignored in the automatic determination of 

196 `sigest`, so either the weights should be O(1), or `sigest` should be 

197 specified by the user. 

198 ntree 

199 The number of trees used to represent the latent mean function. By 

200 default 200 for continuous regression and 50 for binary regression. 

201 numcut 

202 If `usequants` is `False`: the exact number of cutpoints used to bin the 

203 predictors, ranging between the minimum and maximum observed values 

204 (excluded). 

205 

206 If `usequants` is `True`: the maximum number of cutpoints to use for 

207 binning the predictors. Each predictor is binned such that its 

208 distribution in `x_train` is approximately uniform across bins. The 

209 number of bins is at most the number of unique values appearing in 

210 `x_train`, or ``numcut + 1``. 

211 

212 Before running the algorithm, the predictors are compressed to the 

213 smallest integer type that fits the bin indices, so `numcut` is best set 

214 to the maximum value of an unsigned integer type, like 255. 

215 

216 Ignored if `xinfo` is specified. 

217 ndpost 

218 The number of MCMC samples to save, after burn-in. `ndpost` is the 

219 total number of samples across all chains. `ndpost` is rounded up to the 

220 nearest multiple of `mc_cores`. 

221 nskip 

222 The number of initial MCMC samples to discard as burn-in. This number 

223 of samples is discarded from each chain. 

224 keepevery 

225 The thinning factor for the MCMC samples, after burn-in. By default, 1 

226 for continuous regression and 10 for binary regression. 

227 printevery 

228 The number of iterations (including thinned-away ones) between each log 

229 line. Set to `None` to disable logging. 

230 

231 `printevery` has a few unexpected side effects. On CPU, interrupting 

232 with ^C halts the MCMC only at the next log line. Also, the total number of 

233 iterations is a multiple of `printevery`, so if ``nskip + keepevery * 

234 ndpost`` is not a multiple of `printevery`, some of the last iterations 

235 will not be saved. 

236 mc_cores 

237 The number of independent MCMC chains. 

238 seed 

239 The seed for the random number generator. 

240 maxdepth 

241 The maximum depth of the trees. This is 1-based, so with the default 

242 ``maxdepth=6``, the depths of the levels range from 0 to 5. 

243 init_kw 

244 Additional arguments passed to `bartz.mcmcstep.init`. 

245 run_mcmc_kw 

246 Additional arguments passed to `bartz.mcmcloop.run_mcmc`. 

247 

248 Attributes 

249 ---------- 

250 offset : Float32[Array, ''] 

251 The prior mean of the latent mean function. 

252 sigest : Float32[Array, ''] | None 

253 The estimated standard deviation of the error used to set `lamda`. 

254 yhat_test : Float32[Array, 'ndpost m'] | None 

255 The conditional posterior mean at `x_test` for each MCMC iteration. 

256 

257 Notes 

258 ----- 

259 This interface imitates the function ``mc_gbart`` from the R package `BART 

260 <https://cran.r-project.org/package=BART>`_, but with these differences: 

261 

262 - If `x_train` and `x_test` are matrices, they have one predictor per row 

263 instead of per column. 

264 - If ``usequants=False``, R BART switches to quantiles anyway if there are 

265 fewer predictor values than the required number of bins, while bartz 

266 always follows the specification. 

267 - Some functionality is missing. 

268 - The error variance parameter is called `lamda` instead of `lambda`. 

269 - There are some additional attributes, and some missing. 

270 - The trees have a maximum depth. 

271 - `rm_const` refers to predictors without decision rules instead of 

272 predictors that are constant in `x_train`. 

273 - If `rm_const=True` and some variables are dropped, the predictors 

274 matrix/dataframe passed to `predict` should still include them. 

275 

276 References 

277 ---------- 

278 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for 

279 High-Dimensional Prediction and Variable Selection”. In: Journal of the 

280 American Statistical Association 113.522, pp. 626-636. 

281 .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010). 

282 “BART: Bayesian additive regression trees”. In: The Annals of Applied 

283 Statistics 4.1, pp. 266-298. 

284 """ 

285 
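# A minimal usage sketch (illustration only, not from the original file), assuming
# `mc_gbart` is importable from `bartz.BART` and given the (p, n) predictor layout
# and float32 response documented above:
#
#     import numpy as np
#     from bartz.BART import mc_gbart
#     rng = np.random.default_rng(0)
#     X = rng.uniform(size=(5, 100))                  # p=5 predictors, n=100 points
#     y = (X.sum(axis=0) + 0.1 * rng.normal(size=100)).astype(np.float32)
#     fit = mc_gbart(X, y, ndpost=1000, nskip=100, mc_cores=2, seed=0)
#     fit.yhat_train_mean         # marginal posterior mean at x_train
#     fit.sigma_mean              # posterior mean of the error standard deviation
#     fit.predict(X[:, :10])      # posterior mean draws at new points, shape (ndpost, 10)
#     fit2 = mc_gbart(X, y, sparse=True, seed=1)      # Linero-style variable selection
#     fit2.varprob_mean           # posterior predictor-selection probabilities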

286 _main_trace: mcmcloop.MainTrace 1ab

287 _burnin_trace: mcmcloop.BurninTrace 1ab

288 _mcmc_state: mcmcstep.State 1ab

289 _splits: Real[Array, 'p max_num_splits'] 1ab

290 _x_train_fmt: Any = field(static=True) 1ab

291 

292 ndpost: int = field(static=True) 1ab

293 offset: Float32[Array, ''] 1ab

294 sigest: Float32[Array, ''] | None = None 1ab

295 yhat_test: Float32[Array, 'ndpost m'] | None = None 1ab

296 

297 def __init__( 1ab

298 self, 

299 x_train: Real[Array, 'p n'] | DataFrame, 

300 y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series, 

301 *, 

302 x_test: Real[Array, 'p m'] | DataFrame | None = None, 

303 type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002 

304 sparse: bool = False, 

305 theta: FloatLike | None = None, 

306 a: FloatLike = 0.5, 

307 b: FloatLike = 1.0, 

308 rho: FloatLike | None = None, 

309 xinfo: Float[Array, 'p n'] | None = None, 

310 usequants: bool = False, 

311 rm_const: bool | None = True, 

312 sigest: FloatLike | None = None, 

313 sigdf: FloatLike = 3.0, 

314 sigquant: FloatLike = 0.9, 

315 k: FloatLike = 2.0, 

316 power: FloatLike = 2.0, 

317 base: FloatLike = 0.95, 

318 lamda: FloatLike | None = None, 

319 tau_num: FloatLike | None = None, 

320 offset: FloatLike | None = None, 

321 w: Float[Array, ' n'] | None = None, 

322 ntree: int | None = None, 

323 numcut: int = 100, 

324 ndpost: int = 1000, 

325 nskip: int = 100, 

326 keepevery: int | None = None, 

327 printevery: int | None = 100, 

328 mc_cores: int = 2, 

329 seed: int | Key[Array, ''] = 0, 

330 maxdepth: int = 6, 

331 init_kw: dict | None = None, 

332 run_mcmc_kw: dict | None = None, 

333 ): 

334 # check data and put it in the right format 

335 x_train, x_train_fmt = self._process_predictor_input(x_train) 1ab

336 y_train = self._process_response_input(y_train) 1ab

337 self._check_same_length(x_train, y_train) 1ab

338 if w is not None: 1ab

339 w = self._process_response_input(w) 1ab

340 self._check_same_length(x_train, w) 1ab

341 

342 # check data types are correct for continuous/binary regression 

343 self._check_type_settings(y_train, type, w) 1ab

344 # from here onwards, the type is determined by y_train.dtype == bool 

345 

346 # set defaults that depend on type of regression 

347 if ntree is None: 1ab

348 ntree = 50 if y_train.dtype == bool else 200 1ab

349 if keepevery is None: 1ab

350 keepevery = 10 if y_train.dtype == bool else 1 1ab

351 

352 # process sparsity settings 

353 theta, a, b, rho = self._process_sparsity_settings( 1ab

354 x_train, sparse, theta, a, b, rho 

355 ) 

356 

357 # process "standardization" settings 

358 offset = self._process_offset_settings(y_train, offset) 1ab

359 sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num) 1ab

360 lamda, sigest = self._process_error_variance_settings( 1ab

361 x_train, y_train, sigest, sigdf, sigquant, lamda 

362 ) 

363 

364 # determine splits 

365 splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo) 1ab

366 x_train = self._bin_predictors(x_train, splits) 1ab

367 

368 # setup and run mcmc 

369 initial_state = self._setup_mcmc( 1ab

370 x_train, 

371 y_train, 

372 offset, 

373 w, 

374 max_split, 

375 lamda, 

376 sigma_mu, 

377 sigdf, 

378 power, 

379 base, 

380 maxdepth, 

381 ntree, 

382 init_kw, 

383 rm_const, 

384 theta, 

385 a, 

386 b, 

387 rho, 

388 ) 

389 final_state, burnin_trace, main_trace = self._run_mcmc( 1ab

390 initial_state, 

391 mc_cores, 

392 ndpost, 

393 nskip, 

394 keepevery, 

395 printevery, 

396 seed, 

397 run_mcmc_kw, 

398 sparse, 

399 ) 

400 

401 # set public attributes 

402 self.offset = final_state.offset # from the state because of buffer donation 1ab

403 self.ndpost = main_trace.grow_prop_count.size 1ab

404 self.sigest = sigest 1ab

405 

406 # set private attributes 

407 self._main_trace = main_trace 1ab

408 self._burnin_trace = burnin_trace 1ab

409 self._mcmc_state = final_state 1ab

410 self._splits = splits 1ab

411 self._x_train_fmt = x_train_fmt 1ab

412 

413 # predict at test points 

414 if x_test is not None: 1ab

415 self.yhat_test = self.predict(x_test) 1ab

416 

417 @cached_property 1ab

418 def prob_test(self) -> Float32[Array, 'ndpost m'] | None: 1ab

419 """The posterior probability of y being True at `x_test` for each MCMC iteration.""" 

420 if self.yhat_test is None or self._mcmc_state.y.dtype != bool: 1ab

421 return None 1ab

422 else: 

423 return ndtr(self.yhat_test) 1ab

424 

425 @cached_property 1ab

426 def prob_test_mean(self) -> Float32[Array, ' m'] | None: 1ab

427 """The marginal posterior probability of y being True at `x_test`.""" 

428 if self.prob_test is None: 1ab

429 return None 1ab

430 else: 

431 return self.prob_test.mean(axis=0) 1ab

432 

433 @cached_property 1ab

434 def prob_train(self) -> Float32[Array, 'ndpost n'] | None: 1ab

435 """The posterior probability of y being True at `x_train` for each MCMC iteration.""" 

436 if self._mcmc_state.y.dtype == bool: 1ab

437 return ndtr(self.yhat_train) 1ab

438 else: 

439 return None 1ab

440 

441 @cached_property 1ab

442 def prob_train_mean(self) -> Float32[Array, ' n'] | None: 1ab

443 """The marginal posterior probability of y being True at `x_train`.""" 

444 if self.prob_train is None: 1ab

445 return None 1ab

446 else: 

447 return self.prob_train.mean(axis=0) 1ab

448 

449 @cached_property 1ab

450 def sigma( 1ab

451 self, 

452 ) -> ( 

453 Float32[Array, ' nskip+ndpost'] 

454 | Float32[Array, 'nskip+ndpost/mc_cores mc_cores'] 

455 | None 

456 ): 

457 """The standard deviation of the error, including burn-in samples.""" 

458 if self._burnin_trace.sigma2 is None: 1ab

459 return None 1ab

460 assert self._main_trace.sigma2 is not None 1ab

461 sigma = jnp.sqrt( 1ab

462 jnp.concatenate( 

463 [self._burnin_trace.sigma2, self._main_trace.sigma2], axis=1 

464 ) 

465 ) 

466 sigma = sigma.T 1ab

467 _, mc_cores = sigma.shape 1ab

468 if mc_cores == 1: 1ab

469 sigma = sigma.squeeze(1) 1ab

470 return sigma 1ab

471 

472 @cached_property 1ab

473 def sigma_mean(self) -> Float32[Array, ''] | None: 1ab

474 """The mean of `sigma`, only over the post-burnin samples.""" 

475 if self.sigma is None: 1ab

476 return None 1ab

477 _, nskip = self._burnin_trace.grow_prop_count.shape 1ab

478 return self.sigma[nskip:, ...].mean() 1ab

479 

480 @cached_property 1ab

481 def varcount(self) -> Int32[Array, 'ndpost p']: 1ab

482 """Histogram of predictor usage for decision rules in the trees.""" 

483 return self._compute_varcount_multichain_flattened( 1ab

484 self._mcmc_state.forest.max_split.size, self._main_trace 

485 ) 

486 

487 @staticmethod 1ab

488 @partial(jax.vmap, in_axes=(None, 0)) 1ab

489 def _compute_varcount_multichain( 1ab

490 p: int, main_trace: mcmcloop.MainTrace 

491 ) -> Int32[Array, 'mc_cores ndpost/mc_cores p']: 

492 return mcmcloop.compute_varcount(p, main_trace) 1ab

493 

494 @classmethod 1ab

495 @partial(jax.jit, static_argnums=(0, 1)) 1ab

496 def _compute_varcount_multichain_flattened( 1ab

497 cls, p: int, main_trace: mcmcloop.MainTrace 

498 ) -> Int32[Array, 'ndpost p']: 

499 return cls._compute_varcount_multichain(p, main_trace).reshape(-1, p) 1ab

500 

501 @cached_property 1ab

502 def varcount_mean(self) -> Float32[Array, ' p']: 1ab

503 """Average of `varcount` across MCMC iterations.""" 

504 return self.varcount.mean(axis=0) 1ab

505 

506 @cached_property 1ab

507 def varprob(self) -> Float32[Array, 'ndpost p']: 1ab

508 """Posterior samples of the probability of choosing each predictor for a decision rule.""" 

509 max_split = self._mcmc_state.forest.max_split 1ab

510 p = max_split.size 1ab

511 varprob = self._main_trace.varprob 1ab

512 if varprob is None: 1ab

513 peff = jnp.count_nonzero(max_split) 1ab

514 varprob = jnp.where(max_split, 1 / peff, 0) 1ab

515 varprob = jnp.broadcast_to(varprob, (self.ndpost, p)) 1ab

516 else: 

517 varprob = varprob.reshape(-1, p) 1ab

518 return varprob 1ab

519 

520 @cached_property 1ab

521 def varprob_mean(self) -> Float32[Array, ' p']: 1ab

522 """The marginal posterior probability of each predictor being chosen for a decision rule.""" 

523 return self.varprob.mean(axis=0) 1ab

524 

525 @cached_property 1ab

526 def yhat_test_mean(self) -> Float32[Array, ' m'] | None: 1ab

527 """The marginal posterior mean at `x_test`. 

528 

529 Not defined for binary regression because it is error-prone; typically 

530 the right quantity to consider is `prob_test_mean`. 

531 """ 

532 if self.yhat_test is None or self._mcmc_state.y.dtype == bool: 1ab

533 return None 1ab

534 else: 

535 return self.yhat_test.mean(axis=0) 1ab

536 

537 @cached_property 1ab

538 def yhat_train(self) -> Float32[Array, 'ndpost n']: 1ab

539 """The conditional posterior mean at `x_train` for each MCMC iteration.""" 

540 x_train = self._mcmc_state.X 1ab

541 return self._predict(x_train) 1ab

542 

543 @cached_property 1ab

544 def yhat_train_mean(self) -> Float32[Array, ' n'] | None: 1ab

545 """The marginal posterior mean at `x_train`. 

546 

547 Not defined for binary regression because it is error-prone; typically 

548 the right quantity to consider is `prob_train_mean`. 

549 """ 

550 if self._mcmc_state.y.dtype == bool: 1ab

551 return None 1ab

552 else: 

553 return self.yhat_train.mean(axis=0) 1ab

554 

555 def predict( 1ab

556 self, x_test: Real[Array, 'p m'] | DataFrame 

557 ) -> Float32[Array, 'ndpost m']: 

558 """ 

559 Compute the posterior mean at `x_test` for each MCMC iteration. 

560 

561 Parameters 

562 ---------- 

563 x_test 

564 The test predictors. 

565 

566 Returns 

567 ------- 

568 The conditional posterior mean at `x_test` for each MCMC iteration. 

569 

570 Raises 

571 ------ 

572 ValueError 

573 If `x_test` has a different format than `x_train`. 

574 """ 

575 x_test, x_test_fmt = self._process_predictor_input(x_test) 1ab

576 if x_test_fmt != self._x_train_fmt: 1ab

577 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 1ab

578 raise ValueError(msg) 1ab

579 x_test = self._bin_predictors(x_test, self._splits) 1ab

580 return self._predict(x_test) 1ab

581 

582 @staticmethod 1ab

583 def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: 1ab

584 if hasattr(x, 'columns'): 1ab

585 fmt = dict(kind='dataframe', columns=x.columns) 1ab

586 x = x.to_numpy().T 1ab

587 else: 

588 fmt = dict(kind='array', num_covar=x.shape[0]) 1ab

589 x = jnp.asarray(x) 1ab

590 assert x.ndim == 2 1ab

591 return x, fmt 1ab

592 

593 @staticmethod 1ab

594 def _process_response_input(y) -> Shaped[Array, ' n']: 1ab

595 if hasattr(y, 'to_numpy'): 1ab

596 y = y.to_numpy() 1ab

597 y = jnp.asarray(y) 1ab

598 assert y.ndim == 1 1ab

599 return y 1ab

600 

601 @staticmethod 1ab

602 def _check_same_length(x1, x2): 1ab

603 get_length = lambda x: x.shape[-1] 1ab

604 assert get_length(x1) == get_length(x2) 1ab

605 

606 @staticmethod 1ab

607 def _process_error_variance_settings( 1ab

608 x_train, y_train, sigest, sigdf, sigquant, lamda 

609 ) -> tuple[Float32[Array, ''] | None, ...]: 

610 if y_train.dtype == bool: 1ab

611 if sigest is not None: 1ab [branch 611 ↛ 612 not taken: condition never true]

612 msg = 'Let `sigest=None` for binary regression' 

613 raise ValueError(msg) 

614 if lamda is not None: 1ab [branch 614 ↛ 615 not taken: condition never true]

615 msg = 'Let `lamda=None` for binary regression' 

616 raise ValueError(msg) 

617 return None, None 1ab

618 elif lamda is not None: 1ab [branch 618 ↛ 619 not taken: condition never true]

619 if sigest is not None: 

620 msg = 'Let `sigest=None` if `lamda` is specified' 

621 raise ValueError(msg) 

622 return lamda, None 

623 else: 

624 if sigest is not None: 1ab [branch 624 ↛ 625 not taken: condition never true]

625 sigest2 = jnp.square(sigest) 

626 elif y_train.size < 2: 1ab

627 sigest2 = 1 1ab

628 elif y_train.size <= x_train.shape[0]: 1ab

629 sigest2 = jnp.var(y_train) 1ab

630 else: 

631 x_centered = x_train.T - x_train.mean(axis=1) 1ab

632 y_centered = y_train - y_train.mean() 1ab

633 # centering is equivalent to adding an intercept column 

634 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered) 1ab

635 chisq = chisq.squeeze(0) 1ab

636 dof = len(y_train) - rank 1ab

637 sigest2 = chisq / dof 1ab

638 alpha = sigdf / 2 1ab

639 invchi2 = invgamma.ppf(sigquant, alpha) / 2 1ab

640 invchi2rid = invchi2 * sigdf 1ab
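# lamda is chosen so that, under the scaled-inv-chi2(sigdf, lamda) prior on the
# error variance, P(sigma^2 <= sigest^2) = sigquant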

641 return sigest2 / invchi2rid, jnp.sqrt(sigest2) 1ab

642 

643 @staticmethod 1ab

644 def _check_type_settings(y_train, type, w): # noqa: A002 1ab

645 match type: 1ab

646 case 'wbart': 1ab

647 if y_train.dtype != jnp.float32: 1ab [branch 647 ↛ 648 not taken: condition never true]

648 msg = ( 

649 'Continuous regression requires y_train.dtype=float32,' 

650 f' got {y_train.dtype=} instead.' 

651 ) 

652 raise TypeError(msg) 1a

653 case 'pbart': 1ab [branch 653 ↛ 663 not taken: pattern always matched]

654 if w is not None: 1ab [branch 654 ↛ 655 not taken: condition never true]

655 msg = 'Binary regression does not support weights, set `w=None`' 

656 raise ValueError(msg) 

657 if y_train.dtype != bool: 1ab [branch 657 ↛ 658 not taken: condition never true]

658 msg = ( 

659 'Binary regression requires y_train.dtype=bool,' 

660 f' got {y_train.dtype=} instead.' 

661 ) 

662 raise TypeError(msg) 1a

663 case _: 

664 msg = f'Invalid {type=}' 

665 raise ValueError(msg) 

666 

667 @staticmethod 1ab

668 def _process_sparsity_settings( 1ab

669 x_train: Real[Array, 'p n'], 

670 sparse: bool, 

671 theta: FloatLike | None, 

672 a: FloatLike, 

673 b: FloatLike, 

674 rho: FloatLike | None, 

675 ) -> ( 

676 tuple[None, None, None, None] 

677 | tuple[FloatLike, None, None, None] 

678 | tuple[None, FloatLike, FloatLike, FloatLike] 

679 ): 

680 if not sparse: 1ab

681 return None, None, None, None 1ab

682 elif theta is not None: 1ab

683 return theta, None, None, None 1ab

684 else: 

685 if rho is None: 1ab [branch 685 ↛ 688 not taken: condition always true]

686 p, _ = x_train.shape 1ab

687 rho = float(p) 1ab

688 return None, a, b, rho 1ab

689 

690 @staticmethod 1ab

691 def _process_offset_settings( 1ab

692 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

693 offset: float | Float32[Any, ''] | None, 

694 ) -> Float32[Array, '']: 

695 if offset is not None: 1ab [branch 695 ↛ 696 not taken: condition never true]

696 return jnp.asarray(offset) 

697 elif y_train.size < 1: 1ab

698 return jnp.array(0.0) 1ab

699 else: 

700 mean = y_train.mean() 1ab

701 

702 if y_train.dtype == bool: 1ab

703 bound = 1 / (1 + y_train.size) 1ab
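# clip the empirical frequency away from 0 and 1 so that ndtri stays finite:
# an all-False (or all-True) y_train gives offset = Phi^-1(1/(n+1)) (or Phi^-1(n/(n+1)))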

704 mean = jnp.clip(mean, bound, 1 - bound) 1ab

705 return ndtri(mean) 1ab

706 else: 

707 return mean 1ab

708 

709 @staticmethod 1ab

710 def _process_leaf_sdev_settings( 1ab

711 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

712 k: float, 

713 ntree: int, 

714 tau_num: FloatLike | None, 

715 ): 

716 if tau_num is None: 1ab [branch 716 ↛ 724 not taken: condition always true]

717 if y_train.dtype == bool: 1ab

718 tau_num = 3.0 1ab

719 elif y_train.size < 2: 1ab

720 tau_num = 1.0 1ab

721 else: 

722 tau_num = (y_train.max() - y_train.min()) / 2 1ab

723 
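# sigma_mu = tau_num / (k * sqrt(ntree)); with the continuous-regression defaults
# k=2, ntree=200, tau_num=(max(y)-min(y))/2 this is roughly 0.018 * (max(y)-min(y))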

724 return tau_num / (k * math.sqrt(ntree)) 1ab

725 

726 @staticmethod 1ab

727 def _determine_splits( 1ab

728 x_train: Real[Array, 'p n'], 

729 usequants: bool, 

730 numcut: int, 

731 xinfo: Float[Array, 'p n'] | None, 

732 ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]: 

733 if xinfo is not None: 1ab

734 if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]: 1ab

735 msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)' 1ab

736 raise ValueError(msg) 1ab

737 return prepcovars.parse_xinfo(xinfo) 1ab

738 elif usequants: 1ab

739 return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1) 1ab

740 else: 

741 return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1) 1ab

742 

743 @staticmethod 1ab

744 def _bin_predictors(x, splits) -> UInt[Array, 'p n']: 1ab

745 return prepcovars.bin_predictors(x, splits) 1ab

746 

747 @staticmethod 1ab

748 def _setup_mcmc( 1ab

749 x_train: Real[Array, 'p n'], 

750 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

751 offset: Float32[Array, ''], 

752 w: Float[Array, ' n'] | None, 

753 max_split: UInt[Array, ' p'], 

754 lamda: Float32[Array, ''] | None, 

755 sigma_mu: FloatLike, 

756 sigdf: FloatLike, 

757 power: FloatLike, 

758 base: FloatLike, 

759 maxdepth: int, 

760 ntree: int, 

761 init_kw: dict[str, Any] | None, 

762 rm_const: bool | None, 

763 theta: FloatLike | None, 

764 a: FloatLike | None, 

765 b: FloatLike | None, 

766 rho: FloatLike | None, 

767 ): 

768 depth = jnp.arange(maxdepth - 1) 1ab

769 p_nonterminal = base / (1 + depth).astype(float) ** power 1ab
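# with the defaults base=0.95, power=2.0, maxdepth=6 this gives
# p_nonterminal ≈ [0.95, 0.2375, 0.1056, 0.0594, 0.038] for depths 0..4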

770 

771 if y_train.dtype == bool: 1ab

772 sigma2_alpha = None 1ab

773 sigma2_beta = None 1ab

774 else: 

775 sigma2_alpha = sigdf / 2 1ab

776 sigma2_beta = lamda * sigma2_alpha 1ab

777 

778 kw = dict( 1ab

779 X=x_train, 

780 # copy y_train because it's going to be donated in the mcmc loop 

781 y=jnp.array(y_train), 

782 offset=offset, 

783 error_scale=w, 

784 max_split=max_split, 

785 num_trees=ntree, 

786 p_nonterminal=p_nonterminal, 

787 sigma_mu2=jnp.square(sigma_mu), 

788 sigma2_alpha=sigma2_alpha, 

789 sigma2_beta=sigma2_beta, 

790 min_points_per_decision_node=10, 

791 min_points_per_leaf=5, 

792 theta=theta, 

793 a=a, 

794 b=b, 

795 rho=rho, 

796 ) 

797 

798 if rm_const is None: 1ab

799 kw.update(filter_splitless_vars=False) 1ab

800 elif rm_const: 1ab [branch 800 ↛ 803 not taken: condition always true]

801 kw.update(filter_splitless_vars=True) 1ab

802 else: 

803 n_empty = jnp.count_nonzero(max_split == 0) 

804 if n_empty: 

805 msg = f'There are {n_empty}/{max_split.size} predictors without decision rules' 

806 raise ValueError(msg) 

807 kw.update(filter_splitless_vars=False) 

808 

809 if init_kw is not None: 1ab

810 kw.update(init_kw) 1ab

811 

812 return mcmcstep.init(**kw) 1ab

813 

814 @classmethod 1ab

815 def _run_mcmc( 1ab

816 cls, 

817 mcmc_state: mcmcstep.State, 

818 mc_cores: int, 

819 ndpost: int, 

820 nskip: int, 

821 keepevery: int, 

822 printevery: int | None, 

823 seed: int | Integer[Array, ''] | Key[Array, ''], 

824 run_mcmc_kw: dict | None, 

825 sparse: bool, 

826 ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]: 

827 # prepare random generator seed 

828 if isinstance(seed, jax.Array) and jnp.issubdtype( 1ab

829 seed.dtype, jax.dtypes.prng_key 

830 ): 

831 key = seed 1ab

832 else: 

833 key = jax.random.key(seed) 1ab

834 

835 # round up ndpost 

836 ndpost = mc_cores * (ndpost // mc_cores + bool(ndpost % mc_cores)) 1ab
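# e.g. ndpost=1000 with mc_cores=3 rounds up to 3 * (333 + 1) = 1002 saved draws in total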

837 

838 # prepare arguments 

839 kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery) 1ab

840 kw.update( 1ab

841 mcmcloop.make_default_callback( 

842 dot_every=None if printevery is None or printevery == 1 else 1, 

843 report_every=printevery, 

844 sparse_on_at=nskip // 2 if sparse else None, 

845 ) 

846 ) 

847 if run_mcmc_kw is not None: 1ab

848 kw.update(run_mcmc_kw) 1ab

849 

850 if mc_cores == 1: 1ab

851 return cls._single_run_mcmc(key, mcmc_state, ndpost, **kw) 1ab

852 else: 

853 keys = jax.random.split(key, mc_cores) 1ab

854 return cls._vmapped_run_mcmc(keys, mcmc_state, ndpost // mc_cores, **kw) 1ab

855 

856 @classmethod 1ab

857 def _single_run_mcmc( 1ab

858 cls, key: Key[Array, ''], bart: mcmcstep.State, *args, **kwargs 

859 ): 

860 out = mcmcloop.run_mcmc(key, bart, *args, **kwargs) 1ab

861 axes = cls._vmap_axes_for_state(bart) 1ab

862 return jax.vmap(lambda x: x, in_axes=None, out_axes=(axes, 0, 0), axis_size=1)( 1ab

863 out 

864 ) 

865 

866 @classmethod 1ab

867 def _vmapped_run_mcmc( 1ab

868 cls, keys: Key[Array, ' mc_cores'], bart: mcmcstep.State, *args, **kwargs 

869 ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]: 

870 bart_axes = cls._vmap_axes_for_state(bart) 1ab

871 

872 barts = jax.vmap( 1ab

873 lambda x: x, in_axes=None, out_axes=bart_axes, axis_size=keys.size 

874 )(bart) 

875 

876 @partial(jax.vmap, in_axes=(0, bart_axes), out_axes=(bart_axes, 0, 0)) 1ab

877 def _partial_vmapped_run_mcmc(key, bart): 1ab

878 return mcmcloop.run_mcmc(key, bart, *args, **kwargs) 1ab

879 

880 return _partial_vmapped_run_mcmc(keys, barts) 1ab

881 

882 @staticmethod 1ab

883 def _vmap_axes_for_state(state: mcmcstep.State) -> mcmcstep.State: 1ab

884 def choose_vmap_index(path, _) -> Literal[0, None]: 1ab
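# chain-independent fields (data, priors, static settings) are shared and get no
# vmap axis; everything else gets a leading chain axis 0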

885 no_vmap_attrs = ( 1ab

886 '.X', 

887 '.y', 

888 '.offset', 

889 '.prec_scale', 

890 '.sigma2_alpha', 

891 '.sigma2_beta', 

892 '.forest.max_split', 

893 '.forest.blocked_vars', 

894 '.forest.p_nonterminal', 

895 '.forest.p_propose_grow', 

896 '.forest.min_points_per_decision_node', 

897 '.forest.min_points_per_leaf', 

898 '.forest.sigma_mu2', 

899 '.forest.a', 

900 '.forest.b', 

901 '.forest.rho', 

902 ) 

903 str_path = ''.join(map(str, path)) 1ab

904 if str_path in no_vmap_attrs: 1ab

905 return None 1ab

906 else: 

907 return 0 1ab

908 

909 return map_with_path(choose_vmap_index, state) 1ab

910 

911 def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']: 1ab

912 return self._evaluate_chains_flattened(self._main_trace, x) 1ab

913 

914 @classmethod 1ab

915 @partial(jax.jit, static_argnums=(0,)) 1ab

916 def _evaluate_chains_flattened( 1ab

917 cls, trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] 

918 ) -> Float32[Array, 'ndpost m']: 

919 out = cls._evaluate_chains(trace, x) 1ab

920 mc_cores, ndpost_per_chain, m = out.shape 1ab

921 return out.reshape(mc_cores * ndpost_per_chain, m) 1ab

922 

923 @staticmethod 1ab

924 @partial(jax.vmap, in_axes=(0, None)) 1ab

925 def _evaluate_chains( 1ab

926 trace: mcmcloop.MainTrace, x: UInt[Array, 'p m'] 

927 ) -> Float32[Array, 'mc_cores ndpost/mc_cores m']: 

928 return mcmcloop.evaluate_trace(trace, x) 1ab

929 

930 

931class gbart(mc_gbart): 1ab

932 """Subclass of `mc_gbart` that forces `mc_cores=1`.""" 

933 

934 def __init__(self, *args, **kwargs): 1ab

935 if 'mc_cores' in kwargs: 1ab

936 msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'" 1ab

937 raise TypeError(msg) 1ab

938 kwargs.update(mc_cores=1) 1ab

939 super().__init__(*args, **kwargs) 1ab
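# Illustration only (not part of the original file): `gbart` behaves like `mc_gbart`
# with a single chain and rejects `mc_cores`:
#
#     from bartz.BART import gbart
#     fit = gbart(X, y, ndpost=500, seed=0)   # X, y as in the mc_gbart sketch above
#     gbart(X, y, mc_cores=2)                 # raises TypeError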