Coverage for src/bartz/BART.py: 88%

258 statements  


# bartz/src/bartz/BART.py
#
# Copyright (c) 2024-2025, Giacomo Petrillo
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement a class `gbart` that mimics the R BART package."""

26 

27import math 1ab

28from collections.abc import Sequence 1ab

29from functools import cached_property 1ab

30from typing import Any, Literal, Protocol 1ab

31 

32import jax 1ab

33import jax.numpy as jnp 1ab

34from equinox import Module, field 1ab

35from jax.scipy.special import ndtr 1ab

36from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, Shaped, UInt 1ab

37from numpy import ndarray 1ab

38 

39from bartz import mcmcloop, mcmcstep, prepcovars 1ab

40from bartz.jaxext.scipy.special import ndtri 1ab

41from bartz.jaxext.scipy.stats import invgamma 1ab

42 

43FloatLike = float | Float[Any, ''] 1ab

44 

45 

class DataFrame(Protocol):
    """DataFrame duck-type for `gbart`.

    Attributes
    ----------
    columns : Sequence[str]
        The names of the columns.
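
    Notes
    -----
    Any object with a ``columns`` attribute and a ``to_numpy`` method can be
    used; for example, a pandas ``DataFrame`` satisfies this protocol.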

    """

    columns: Sequence[str]

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...

class Series(Protocol):
    """Series duck-type for `gbart`.

    Attributes
    ----------
    name : str | None
        The name of the series.
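
    Notes
    -----
    Any object with a ``name`` attribute and a ``to_numpy`` method can be
    used; for example, a pandas ``Series`` satisfies this protocol.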

    """

    name: str | None

    def to_numpy(self) -> ndarray:
        """Convert the series to a 1d numpy array."""
        ...

class gbart(Module):
    """
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart'
        for binary regression with probit link.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants`
        and `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor.
        If there are fewer cutpoints than the number of columns in the
        matrix, fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e.,
        there are no available cutpoints for that predictor). If `True`
        (default), they are ignored. If `False`, an error is raised if there
        are any. If `None`, no check is performed, and the output of the
        MCMC may not make sense if there are predictors without cutpoints.
        The option `None` is provided only to allow jax tracing.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is
        specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale
        is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability
        that a node at depth `d` (0-based) is non-terminal is ``base / (1 +
        d) ** power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of
        x is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to
        ``(max(y_train) - min(y_train)) / 2`` (or 1 if `y_train` has fewer
        than two elements) for continuous regression, and 3 for binary
        regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for
        all datapoints. Note: `w` is ignored in the automatic determination
        of `sigest`, so either the weights should be O(1), or `sigest`
        should be specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, ranging between the minimum and maximum observed
        values (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in.
    nskip
        The number of initial MCMC samples to discard as burn-in.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default,
        1 for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging.

        `printevery` has a few unexpected side effects. On cpu, interrupting
        with ^C halts the MCMC only on the next log. And the total number of
        iterations is a multiple of `printevery`, so if ``nskip + keepevery
        * ndpost`` is not a multiple of `printevery`, some of the last
        iterations will not be saved.
    seed
        The seed for the random number generator.
    maxdepth
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw
        Additional arguments passed to `bartz.mcmcstep.init`.
    run_mcmc_kw
        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.

    Attributes
    ----------
    offset : Float32[Array, '']
        The prior mean of the latent mean function.
    sigest : Float32[Array, ''] | None
        The estimated standard deviation of the error used to set `lamda`.
    sigma : Float32[Array, 'nskip+ndpost'] | None
        The standard deviation of the error, including burn-in samples.
    yhat_test : Float32[Array, 'ndpost m'] | None
        The conditional posterior mean at `x_test` for each MCMC iteration.

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per
      row instead of per column.
    - If ``usequants=False``, R BART switches to quantiles anyway if there
      are fewer predictor values than the required number of bins, while
      bartz always follows the specification.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - Some functionality is missing (e.g., variable selection).
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
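
    Examples
    --------
    A minimal sketch of continuous regression on synthetic data; the shapes
    and dtypes follow the conventions documented above:

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> p, n = 2, 50
    >>> X = rng.standard_normal((p, n))  # one predictor per row
    >>> y = (X.sum(axis=0) + 0.1 * rng.standard_normal(n)).astype(np.float32)
    >>> fit = gbart(X, y, ndpost=200, nskip=100, printevery=None)
    >>> fit.yhat_train_mean.shape  # marginal posterior mean at x_train
    (50,)
    >>> fit.predict(X).shape  # one posterior mean per saved MCMC sample
    (200, 50)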

    """

    _main_trace: mcmcloop.MainTrace
    _mcmc_state: mcmcstep.State
    _splits: Real[Array, 'p max_num_splits']
    _x_train_fmt: Any = field(static=True)

    ndpost: int = field(static=True)
    offset: Float32[Array, '']
    sigma: Float32[Array, ' nskip+ndpost'] | None = None
    sigest: Float32[Array, ''] | None = None
    yhat_test: Float32[Array, 'ndpost m'] | None = None

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool | None = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: dict | None = None,
        run_mcmc_kw: dict | None = None,
    ):
        # check data and put it in the right format
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train, _ = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            w, _ = self._process_response_input(w)
            self._check_same_length(x_train, w)

        # check data types are correct for continuous/binary regression
        self._check_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool

        # set defaults that depend on type of regression
        if ntree is None:
            ntree = 50 if y_train.dtype == bool else 200
        if keepevery is None:
            keepevery = 10 if y_train.dtype == bool else 1

        # process "standardization" settings
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        # determine splits
        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
        x_train = self._bin_predictors(x_train, splits)

        # setup and run mcmc
        initial_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            ntree,
            init_kw,
            rm_const,
        )
        final_state, burnin_trace, main_trace = self._run_mcmc(
            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        # set public attributes
        self.offset = final_state.offset  # from the state because of buffer donation
        self.ndpost = ndpost
        self.sigest = sigest
        self.sigma = self._extract_sigma(burnin_trace, main_trace)

        # set private attributes
        self._main_trace = main_trace
        self._mcmc_state = final_state
        self._splits = splits
        self._x_train_fmt = x_train_fmt

        # predict at test points
        if x_test is not None:
            self.yhat_test = self.predict(x_test)

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
            return None
        else:
            return ndtr(self.yhat_test)

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        if self.prob_test is None:
            return None
        else:
            return self.prob_test.mean(axis=0)

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        if self._mcmc_state.y.dtype == bool:
            return ndtr(self.yhat_train)
        else:
            return None

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        if self.prob_train is None:
            return None
        else:
            return self.prob_train.mean(axis=0)

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        if self.sigma is None:
            return None
        else:
            return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return mcmcloop.compute_varcount(
            self._mcmc_state.forest.max_split.size, self._main_trace
        )

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self.varcount.mean(axis=0)

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone;
        typically the right thing to consider would be `prob_test_mean`.
        """
        if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
            return None
        else:
            return self.yhat_test.mean(axis=0)

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        x_train = self._mcmc_state.X
        return self._predict(x_train)

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone;
        typically the right thing to consider would be `prob_train_mean`.
        """
        if self._mcmc_state.y.dtype == bool:
            return None
        else:
            return self.yhat_train.mean(axis=0)

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`.
        """
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        if x_test_fmt != self._x_train_fmt:
            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
            raise ValueError(msg)
        x_test = self._bin_predictors(x_test, self._splits)
        return self._predict(x_test)

    @staticmethod
    def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(y) -> tuple[Shaped[Array, ' n'], Any]:
        if hasattr(y, 'to_numpy'):
            fmt = dict(kind='series', name=y.name)
            y = y.to_numpy()
        else:
            fmt = dict(kind='array')
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y, fmt

    @staticmethod
    def _check_same_length(x1, x2):
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @staticmethod
    def _process_error_variance_settings(
        x_train, y_train, sigest, sigdf, sigquant, lamda
    ) -> tuple[Float32[Array, ''] | None, ...]:
        if y_train.dtype == bool:
            if sigest is not None:
                msg = 'Let `sigest=None` for binary regression'
                raise ValueError(msg)
            if lamda is not None:
                msg = 'Let `lamda=None` for binary regression'
                raise ValueError(msg)
            return None, None
        elif lamda is not None:
            if sigest is not None:
                msg = 'Let `sigest=None` if `lamda` is specified'
                raise ValueError(msg)
            return lamda, None
        else:
            if sigest is not None:
                sigest2 = jnp.square(sigest)
            elif y_train.size < 2:
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                sigest2 = jnp.var(y_train)
            else:
                x_centered = x_train.T - x_train.mean(axis=1)
                y_centered = y_train - y_train.mean()
                # centering is equivalent to adding an intercept column
                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                chisq = chisq.squeeze(0)
                dof = len(y_train) - rank
                sigest2 = chisq / dof

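            # set lamda so that the prior quantile `sigquant` of the scaled
            # inverse-chisquared prior on the error variance matches sigest2;
            # the computation uses the equivalent inverse-gamma
            # parametrization with alpha = sigdf / 2 and beta = lamda * alpha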
            alpha = sigdf / 2
            invchi2 = invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _check_type_settings(y_train, type, w):  # noqa: A002
        match type:
            case 'wbart':
                if y_train.dtype != jnp.float32:
                    msg = (
                        'Continuous regression requires y_train.dtype=float32,'
                        f' got {y_train.dtype=} instead.'
                    )
                    raise TypeError(msg)
            case 'pbart':
                if w is not None:
                    msg = 'Binary regression does not support weights, set `w=None`'
                    raise ValueError(msg)
                if y_train.dtype != bool:
                    msg = (
                        'Binary regression requires y_train.dtype=bool,'
                        f' got {y_train.dtype=} instead.'
                    )
                    raise TypeError(msg)
            case _:
                msg = f'Invalid {type=}'
                raise ValueError(msg)

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: float | Float32[Any, ''] | None,
    ) -> Float32[Array, '']:
        if offset is not None:
            return jnp.asarray(offset)
        elif y_train.size < 1:
            return jnp.array(0.0)
        else:
            mean = y_train.mean()

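            # for binary regression, clip the observed frequency of True
            # away from 0 and 1 so that the probit offset ndtri(mean) stays
            # finite even when y_train is all False or all True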
            if y_train.dtype == bool:
                bound = 1 / (1 + y_train.size)
                mean = jnp.clip(mean, bound, 1 - bound)
                return ndtri(mean)
            else:
                return mean

    @staticmethod
    def _process_leaf_sdev_settings(
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        k: float,
        ntree: int,
        tau_num: FloatLike | None,
    ):
        if tau_num is None:
            if y_train.dtype == bool:
                tau_num = 3.0
            elif y_train.size < 2:
                tau_num = 1.0
            else:
                tau_num = (y_train.max() - y_train.min()) / 2

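        # the latent mean function is a sum of ntree i.i.d. leaves, each
        # with prior standard deviation sigma_mu, so this choice gives the
        # sum a prior standard deviation of tau_num / k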
        return tau_num / (k * math.sqrt(ntree))

    @staticmethod
    def _determine_splits(
        x_train: Real[Array, 'p n'],
        usequants: bool,
        numcut: int,
        xinfo: Float[Array, 'p n'] | None,
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        if xinfo is not None:
            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
                raise ValueError(msg)
            return prepcovars.parse_xinfo(xinfo)
        elif usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: Float32[Array, ''],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        lamda: Float32[Array, ''] | None,
        sigma_mu: FloatLike,
        sigdf: FloatLike,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        ntree: int,
        init_kw: dict[str, Any] | None,
        rm_const: bool | None,
    ):

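        # prior probability that a node at depth d (0-based) is
        # non-terminal: p_nonterminal[d] = base / (1 + d) ** power, defined
        # for depths 0 to maxdepth - 2 so that nodes at the deepest allowed
        # level are always terminal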
        depth = jnp.arange(maxdepth - 1)
        p_nonterminal = base / (1 + depth).astype(float) ** power

        if y_train.dtype == bool:
            sigma2_alpha = None
            sigma2_beta = None
        else:
            sigma2_alpha = sigdf / 2
            sigma2_beta = lamda * sigma2_alpha

        kw = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma_mu2=jnp.square(sigma_mu),
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_decision_node=10,
            min_points_per_leaf=5,
        )

        if rm_const is None:
            kw.update(filter_splitless_vars=False)
        elif rm_const:
            kw.update(filter_splitless_vars=True)
        else:
            n_empty = jnp.count_nonzero(max_split == 0)
            if n_empty:
                msg = f'There are {n_empty} predictors without decision rules'
                raise ValueError(msg)
            kw.update(filter_splitless_vars=False)

        if init_kw is not None:
            kw.update(init_kw)

        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
        if isinstance(seed, jax.Array) and jnp.issubdtype(
            seed.dtype, jax.dtypes.prng_key
        ):
            # copy because the inner loop in run_mcmc will donate the buffer
            key = seed.copy()
        else:
            key = jax.random.key(seed)

        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
        if printevery is not None:
            kw.update(
                mcmcloop.make_print_callback(None if printevery == 1 else 1, printevery)
            )
        if run_mcmc_kw is not None:
            kw.update(run_mcmc_kw)

        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

    @staticmethod
    def _extract_sigma(
        burnin_trace: mcmcloop.BurninTrace, main_trace: mcmcloop.MainTrace
    ) -> Float32[Array, ' trace_length'] | None:
        if burnin_trace.sigma2 is None:
            return None
        else:
            assert main_trace.sigma2 is not None
            return jnp.sqrt(jnp.concatenate([burnin_trace.sigma2, main_trace.sigma2]))

    def _predict(self, x):
        return mcmcloop.evaluate_trace(self._main_trace, x)