Coverage for src/bartz/BART.py: 89%
279 statements
coverage.py v7.9.2, created at 2025-07-07 22:47 +0000
# bartz/src/bartz/BART.py
#
# Copyright (c) 2024-2025, Giacomo Petrillo
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement a class `gbart` that mimics the R BART package."""

import math
from collections.abc import Sequence
from functools import cached_property
from typing import Any, Literal, Protocol

import jax
import jax.numpy as jnp
from equinox import Module, field
from jax.scipy.special import ndtr
from jaxtyping import (
    Array,
    Bool,
    Float,
    Float32,
    Int32,
    Integer,
    Key,
    Real,
    Shaped,
    UInt,
)
from numpy import ndarray

from bartz import mcmcloop, mcmcstep, prepcovars
from bartz.jaxext.scipy.special import ndtri
from bartz.jaxext.scipy.stats import invgamma

FloatLike = float | Float[Any, '']


class DataFrame(Protocol):
    """DataFrame duck-type for `gbart`.

    Attributes
    ----------
    columns : Sequence[str]
        The names of the columns.
    """

    columns: Sequence[str]

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...


class Series(Protocol):
    """Series duck-type for `gbart`.

    Attributes
    ----------
    name : str | None
        The name of the series.
    """

    name: str | None

    def to_numpy(self) -> ndarray:
        """Convert the series to a 1d numpy array."""
        ...


class gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision
        rule is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or
        lower as well.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any. If
        `None`, no check is performed, and the output of the MCMC may not make
        sense if there are predictors without cutpoints. The option `None` is
        provided only to allow jax tracing.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train)
        - min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, ranging between the minimum and maximum observed
        values (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in.
    nskip
        The number of initial MCMC samples to discard as burn-in.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging.

        `printevery` has a few unexpected side effects. On CPU, interrupting
        with ^C halts the MCMC only on the next log. And the total number of
        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
        ndpost`` is not a multiple of `printevery`, some of the last
        iterations will not be saved.
    seed
        The seed for the random number generator.
    maxdepth
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw
        Additional arguments passed to `bartz.mcmcstep.init`.
    run_mcmc_kw
        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.

    Attributes
    ----------
    offset : Float32[Array, '']
        The prior mean of the latent mean function.
    sigest : Float32[Array, ''] | None
        The estimated standard deviation of the error used to set `lamda`.
    yhat_test : Float32[Array, 'ndpost m'] | None
        The conditional posterior mean at `x_test` for each MCMC iteration.

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing ones.
    - The trees have a maximum depth.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
        High-Dimensional Prediction and Variable Selection”. In: Journal of
        the American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
        “BART: Bayesian additive regression trees”. In: The Annals of Applied
        Statistics 4.1, pp. 266-298.
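
    Examples
    --------
    A minimal usage sketch on synthetic data (this example is illustrative
    and not part of the original documentation; the data and settings are
    made up, and continuous regression with default priors is assumed):

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from bartz.BART import gbart
    >>> key = jax.random.key(0)
    >>> x_train = jax.random.uniform(key, (2, 100))  # p=2 predictors, one per row
    >>> y_train = x_train[0] + jnp.sin(3 * x_train[1])  # toy response, shape (n,)
    >>> fit = gbart(x_train, y_train, ndpost=100, nskip=100, printevery=None)
    >>> yhat = fit.yhat_train_mean      # posterior mean at the training points
    >>> draws = fit.predict(x_train)    # per-iteration means, shape (ndpost, n)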
278 """
    _main_trace: mcmcloop.MainTrace
    _burnin_trace: mcmcloop.BurninTrace
    _mcmc_state: mcmcstep.State
    _splits: Real[Array, 'p max_num_splits']
    _x_train_fmt: Any = field(static=True)

    ndpost: int = field(static=True)
    offset: Float32[Array, '']
    sigest: Float32[Array, ''] | None = None
    yhat_test: Float32[Array, 'ndpost m'] | None = None

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool | None = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: dict | None = None,
        run_mcmc_kw: dict | None = None,
    ):
        # check data and put it in the right format
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            w = self._process_response_input(w)
            self._check_same_length(x_train, w)

        # check data types are correct for continuous/binary regression
        self._check_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool

        # set defaults that depend on type of regression
        if ntree is None:
            ntree = 50 if y_train.dtype == bool else 200
        if keepevery is None:
            keepevery = 10 if y_train.dtype == bool else 1

        # process sparsity settings
        theta, a, b, rho = self._process_sparsity_settings(
            x_train, sparse, theta, a, b, rho
        )

        # process "standardization" settings
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        # determine splits
        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
        x_train = self._bin_predictors(x_train, splits)

        # setup and run mcmc
        initial_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            ntree,
            init_kw,
            rm_const,
            theta,
            a,
            b,
            rho,
        )
        final_state, burnin_trace, main_trace = self._run_mcmc(
            initial_state,
            ndpost,
            nskip,
            keepevery,
            printevery,
            seed,
            run_mcmc_kw,
            sparse,
        )

        # set public attributes
        self.offset = final_state.offset  # from the state because of buffer donation
        self.ndpost = ndpost
        self.sigest = sigest

        # set private attributes
        self._main_trace = main_trace
        self._burnin_trace = burnin_trace
        self._mcmc_state = final_state
        self._splits = splits
        self._x_train_fmt = x_train_fmt

        # predict at test points
        if x_test is not None:
            self.yhat_test = self.predict(x_test)

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
            return None
        else:
            return ndtr(self.yhat_test)

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        if self.prob_test is None:
            return None
        else:
            return self.prob_test.mean(axis=0)

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        if self._mcmc_state.y.dtype == bool:
            return ndtr(self.yhat_train)
        else:
            return None

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        if self.prob_train is None:
            return None
        else:
            return self.prob_train.mean(axis=0)

    @cached_property
    def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None:
        """The standard deviation of the error, including burn-in samples."""
        if self._burnin_trace.sigma2 is None:
            return None
        else:
            assert self._main_trace.sigma2 is not None
            return jnp.sqrt(
                jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2])
            )

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        if self.sigma is None:
            return None
        else:
            return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return mcmcloop.compute_varcount(
            self._mcmc_state.forest.max_split.size, self._main_trace
        )

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self.varcount.mean(axis=0)

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        varprob = self._main_trace.varprob
        if varprob is None:
            max_split = self._mcmc_state.forest.max_split
            p = max_split.size
            peff = jnp.count_nonzero(max_split)
            varprob = jnp.where(max_split, 1 / peff, 0)
            varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
        return varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self.varprob.mean(axis=0)

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone; typically
        the right quantity to consider would be `prob_test_mean`.
        """
        if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
            return None
        else:
            return self.yhat_test.mean(axis=0)

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        x_train = self._mcmc_state.X
        return self._predict(x_train)

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone; typically
        the right quantity to consider would be `prob_train_mean`.
        """
        if self._mcmc_state.y.dtype == bool:
            return None
        else:
            return self.yhat_train.mean(axis=0)

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`.
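
        Examples
        --------
        An illustrative sketch, assuming `fit` is a fitted `gbart` instance
        and `x_new` uses the same format (and the same predictors) as the
        training matrix/dataframe:

        >>> yhat_new = fit.predict(x_new)        # shape (ndpost, m)
        >>> post_mean = yhat_new.mean(axis=0)    # marginal posterior mean at x_new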
538 """
539 x_test, x_test_fmt = self._process_predictor_input(x_test) 1ab
540 if x_test_fmt != self._x_train_fmt: 1ab
541 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 1ab
542 raise ValueError(msg) 1ab
543 x_test = self._bin_predictors(x_test, self._splits) 1ab
544 return self._predict(x_test) 1ab
    @staticmethod
    def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(y) -> Shaped[Array, ' n']:
        if hasattr(y, 'to_numpy'):
            y = y.to_numpy()
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y

    @staticmethod
    def _check_same_length(x1, x2):
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @staticmethod
    def _process_error_variance_settings(
        x_train, y_train, sigest, sigdf, sigquant, lamda
    ) -> tuple[Float32[Array, ''] | None, ...]:
        if y_train.dtype == bool:
            if sigest is not None:
                msg = 'Let `sigest=None` for binary regression'
                raise ValueError(msg)
            if lamda is not None:
                msg = 'Let `lamda=None` for binary regression'
                raise ValueError(msg)
            return None, None
        elif lamda is not None:
            if sigest is not None:
                msg = 'Let `sigest=None` if `lamda` is specified'
                raise ValueError(msg)
            return lamda, None
        else:
            if sigest is not None:
                sigest2 = jnp.square(sigest)
            elif y_train.size < 2:
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                sigest2 = jnp.var(y_train)
            else:
                x_centered = x_train.T - x_train.mean(axis=1)
                y_centered = y_train - y_train.mean()
                # centering is equivalent to adding an intercept column
                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                chisq = chisq.squeeze(0)
                dof = len(y_train) - rank
                sigest2 = chisq / dof
            alpha = sigdf / 2
            invchi2 = invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _check_type_settings(y_train, type, w):  # noqa: A002
        match type:
            case 'wbart':
                if y_train.dtype != jnp.float32:
                    msg = (
                        'Continuous regression requires y_train.dtype=float32,'
                        f' got {y_train.dtype=} instead.'
                    )
                    raise TypeError(msg)
            case 'pbart':
                if w is not None:
                    msg = 'Binary regression does not support weights, set `w=None`'
                    raise ValueError(msg)
                if y_train.dtype != bool:
                    msg = (
                        'Binary regression requires y_train.dtype=bool,'
                        f' got {y_train.dtype=} instead.'
                    )
                    raise TypeError(msg)
            case _:
                msg = f'Invalid {type=}'
                raise ValueError(msg)

    @staticmethod
    def _process_sparsity_settings(
        x_train: Real[Array, 'p n'],
        sparse: bool,
        theta: FloatLike | None,
        a: FloatLike,
        b: FloatLike,
        rho: FloatLike | None,
    ) -> (
        tuple[None, None, None, None]
        | tuple[FloatLike, None, None, None]
        | tuple[None, FloatLike, FloatLike, FloatLike]
    ):
        if not sparse:
            return None, None, None, None
        elif theta is not None:
            return theta, None, None, None
        else:
            if rho is None:
                p, _ = x_train.shape
                rho = float(p)
            return None, a, b, rho

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: float | Float32[Any, ''] | None,
    ) -> Float32[Array, '']:
        if offset is not None:
            return jnp.asarray(offset)
        elif y_train.size < 1:
            return jnp.array(0.0)
        else:
            mean = y_train.mean()

            if y_train.dtype == bool:
                bound = 1 / (1 + y_train.size)
                mean = jnp.clip(mean, bound, 1 - bound)
                return ndtri(mean)
            else:
                return mean

    @staticmethod
    def _process_leaf_sdev_settings(
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        k: float,
        ntree: int,
        tau_num: FloatLike | None,
    ):
        if tau_num is None:
            if y_train.dtype == bool:
                tau_num = 3.0
            elif y_train.size < 2:
                tau_num = 1.0
            else:
                tau_num = (y_train.max() - y_train.min()) / 2

        return tau_num / (k * math.sqrt(ntree))

    @staticmethod
    def _determine_splits(
        x_train: Real[Array, 'p n'],
        usequants: bool,
        numcut: int,
        xinfo: Float[Array, 'p n'] | None,
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        if xinfo is not None:
            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
                raise ValueError(msg)
            return prepcovars.parse_xinfo(xinfo)
        elif usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: Float32[Array, ''],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        lamda: Float32[Array, ''] | None,
        sigma_mu: FloatLike,
        sigdf: FloatLike,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        ntree: int,
        init_kw: dict[str, Any] | None,
        rm_const: bool | None,
        theta: FloatLike | None,
        a: FloatLike | None,
        b: FloatLike | None,
        rho: FloatLike | None,
    ):
        depth = jnp.arange(maxdepth - 1)
        p_nonterminal = base / (1 + depth).astype(float) ** power

        if y_train.dtype == bool:
            sigma2_alpha = None
            sigma2_beta = None
        else:
            sigma2_alpha = sigdf / 2
            sigma2_beta = lamda * sigma2_alpha

        kw = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma_mu2=jnp.square(sigma_mu),
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_decision_node=10,
            min_points_per_leaf=5,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
        )

        if rm_const is None:
            kw.update(filter_splitless_vars=False)
        elif rm_const:
            kw.update(filter_splitless_vars=True)
        else:
            n_empty = jnp.count_nonzero(max_split == 0)
            if n_empty:
                msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
                raise ValueError(msg)
            kw.update(filter_splitless_vars=False)

        if init_kw is not None:
            kw.update(init_kw)

        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(
        mcmc_state: mcmcstep.State,
        ndpost: int,
        nskip: int,
        keepevery: int,
        printevery: int | None,
        seed: int | Integer[Array, ''] | Key[Array, ''],
        run_mcmc_kw: dict | None,
        sparse: bool,
    ):
        # prepare random generator seed
        if isinstance(seed, jax.Array) and jnp.issubdtype(
            seed.dtype, jax.dtypes.prng_key
        ):
            key = seed.copy()
            # copy because the inner loop in run_mcmc will donate the buffer
        else:
            key = jax.random.key(seed)

        # prepare arguments
        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
        kw.update(
            mcmcloop.make_default_callback(
                dot_every=None if printevery is None or printevery == 1 else 1,
                report_every=printevery,
                sparse_on_at=nskip // 2 if sparse else None,
            )
        )
        if run_mcmc_kw is not None:
            kw.update(run_mcmc_kw)

        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

    def _predict(self, x):
        return mcmcloop.evaluate_trace(self._main_trace, x)