Coverage for src/bartz/BART.py: 91%
336 statements
coverage.py v7.10.1, created at 2025-07-31 16:09 +0000
1# bartz/src/bartz/BART.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement classes `mc_gbart` and `gbart` that mimic the R BART package."""
27import math 1ab
28from collections.abc import Sequence 1ab
29from functools import cached_property, partial 1ab
30from typing import Any, Literal, Protocol 1ab
32import jax 1ab
33import jax.numpy as jnp 1ab
34from equinox import Module, field 1ab
35from jax.scipy.special import ndtr 1ab
36from jax.tree import map_with_path 1ab
37from jaxtyping import ( 1ab
38 Array,
39 Bool,
40 Float,
41 Float32,
42 Int32,
43 Integer,
44 Key,
45 Real,
46 Shaped,
47 UInt,
48)
49from numpy import ndarray 1ab
51from bartz import mcmcloop, mcmcstep, prepcovars 1ab
52from bartz.jaxext.scipy.special import ndtri 1ab
53from bartz.jaxext.scipy.stats import invgamma 1ab
55FloatLike = float | Float[Any, ''] 1ab
58class DataFrame(Protocol): 1ab
59 """DataFrame duck-type for `mc_gbart`.
61 Attributes
62 ----------
63 columns : Sequence[str]
64 The names of the columns.
65 """
67 columns: Sequence[str] 1ab
69 def to_numpy(self) -> ndarray: 1ab
70 """Convert the dataframe to a 2d numpy array with columns on the second axis."""
71 ...
74class Series(Protocol): 1ab
75 """Series duck-type for `mc_gbart`.
77 Attributes
78 ----------
79 name : str | None
80 The name of the series.
81 """
83 name: str | None 1ab
85 def to_numpy(self) -> ndarray: 1ab
86 """Convert the series to a 1d numpy array."""
87 ...
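# A minimal sketch, for illustration only, of a container that satisfies the
# `DataFrame` protocol above; the class name is hypothetical and not part of
# the bartz API. A pandas DataFrame works as-is, since it already exposes
# `columns` and `to_numpy()`.
class _ExampleFrame:
    """Toy frame holding a dict of equal-length 1d numpy arrays, one per column."""

    def __init__(self, data: dict[str, ndarray]):
        self.columns: Sequence[str] = list(data)
        self._data = data

    def to_numpy(self) -> ndarray:
        from numpy import stack

        # columns on the second axis, as `mc_gbart` expects
        return stack([self._data[c] for c in self.columns], axis=1)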
90class mc_gbart(Module): 1ab
91 R"""
92 Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
94 Regress `y_train` on `x_train` with a latent mean function represented as
95 a sum of decision trees. The inference is carried out by sampling the
96 posterior distribution of the tree ensemble with an MCMC.
98 Parameters
99 ----------
100 x_train
101 The training predictors.
102 y_train
103 The training responses.
104 x_test
105 The test predictors.
106 type
107 The type of regression. 'wbart' for continuous regression, 'pbart' for
108 binary regression with probit link.
109 sparse
110 Whether to activate variable selection on the predictors as done in
111 [1]_.
112 theta
113 a
114 b
115 rho
116 Hyperparameters of the sparsity prior used for variable selection.
118 The prior distribution on the choice of predictor for each decision rule
119 is
121 .. math::
122 (s_1, \ldots, s_p) \sim
123 \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
125 If `theta` is not specified, it's a priori distributed according to
127 .. math::
128 \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
129 \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
131 If not specified, `rho` is set to the number of predictors p. To tune
132 the prior, consider setting a lower `rho` to prefer more sparsity.
133 If setting `theta` directly, it should be in the ballpark of p or lower
134 as well.
135 xinfo
136 A matrix with the cutpoints to use to bin each predictor. If not
137 specified, it is generated automatically according to `usequants` and
138 `numcut`.
140 Each row shall contain a sorted list of cutpoints for a predictor. If
141 there are fewer cutpoints than the number of columns in the matrix,
142 fill the remaining cells with NaN.
144 `xinfo` shall be a matrix even if `x_train` is a dataframe.
145 usequants
146 Whether to use predictor quantiles instead of a uniform grid to bin
147 predictors. Ignored if `xinfo` is specified.
148 rm_const
149 How to treat predictors with no associated decision rules (i.e., there
150 are no available cutpoints for that predictor). If `True` (default),
151 they are ignored. If `False`, an error is raised if there are any. If
152 `None`, no check is performed, and the output of the MCMC may not make
153 sense if there are predictors without cutpoints. The option `None` is
154 provided only to allow jax tracing.
155 sigest
156 An estimate of the residual standard deviation on `y_train`, used to set
157 `lamda`. If not specified, it is estimated by linear regression (with
158 intercept, and without taking into account `w`). If `y_train` has fewer
159 than two elements, it is set to 1. If n <= p, it is set to the standard
160 deviation of `y_train`. Ignored if `lamda` is specified.
161 sigdf
162 The degrees of freedom of the scaled inverse-chisquared prior on the
163 noise variance.
164 sigquant
165 The quantile of the prior on the noise variance that shall match
166 `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
167 k
168 The inverse scale of the prior standard deviation on the latent mean
169 function, relative to half the observed range of `y_train`. If `y_train`
170 has fewer than two elements, `k` is ignored and the scale is set to 1.
171 power
172 base
173 Parameters of the prior on tree node generation. The probability that a
174 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
175 power``.
176 lamda
177 The prior harmonic mean of the error variance. (The harmonic mean of x
178 is 1/mean(1/x).) If not specified, it is set based on `sigest` and
179 `sigquant`.
180 tau_num
181 The numerator in the expression that determines the prior standard
182 deviation of leaves. If not specified, it defaults to ``(max(y_train) -
183 min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
184 continuous regression, and 3 for binary regression.
185 offset
186 The prior mean of the latent mean function. If not specified, it is set
187 to the mean of `y_train` for continuous regression, and to
188 ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
189 `offset` is set to 0. With binary regression, if `y_train` is all
190 `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
191 ``Phi^-1(n/(n+1))``, respectively.
192 w
193 Coefficients that rescale the error standard deviation on each
194 datapoint. Not specifying `w` is equivalent to setting it to 1 for all
195 datapoints. Note: `w` is ignored in the automatic determination of
196 `sigest`, so either the weights should be O(1), or `sigest` should be
197 specified by the user.
198 ntree
199 The number of trees used to represent the latent mean function. By
200 default 200 for continuous regression and 50 for binary regression.
201 numcut
202 If `usequants` is `False`: the exact number of cutpoints used to bin the
203 predictors, spaced uniformly between the minimum and maximum observed
204 values (both excluded).
206 If `usequants` is `True`: the maximum number of cutpoints to use for
207 binning the predictors. Each predictor is binned such that its
208 distribution in `x_train` is approximately uniform across bins. The
209 number of bins is at most the smaller of ``numcut + 1`` and the number
210 of unique values appearing in `x_train`.
212 Before running the algorithm, the predictors are compressed to the
213 smallest integer type that fits the bin indices, so `numcut` is best set
214 to the maximum value of an unsigned integer type, like 255.
216 Ignored if `xinfo` is specified.
217 ndpost
218 The number of MCMC samples to save, after burn-in. `ndpost` is the
219 total number of samples across all chains. `ndpost` is rounded up to a
220 multiple of `mc_cores`.
221 nskip
222 The number of initial MCMC samples to discard as burn-in. This number
223 of samples is discarded from each chain.
224 keepevery
225 The thinning factor for the MCMC samples, after burn-in. By default, 1
226 for continuous regression and 10 for binary regression.
227 printevery
228 The number of iterations (including thinned-away ones) between each log
229 line. Set to `None` to disable logging.
231 `printevery` has a few unexpected side effects. On CPU, interrupting
232 with ^C halts the MCMC only at the next log line. Also, the total number of
233 iterations is a multiple of `printevery`, so if ``nskip + keepevery *
234 ndpost`` is not a multiple of `printevery`, some of the last iterations
235 will not be saved.
236 mc_cores
237 The number of independent MCMC chains.
238 seed
239 The seed for the random number generator.
240 maxdepth
241 The maximum depth of the trees. This is 1-based, so with the default
242 ``maxdepth=6``, the depths of the levels range from 0 to 5.
243 init_kw
244 Additional arguments passed to `bartz.mcmcstep.init`.
245 run_mcmc_kw
246 Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
248 Attributes
249 ----------
250 offset : Float32[Array, '']
251 The prior mean of the latent mean function.
252 sigest : Float32[Array, ''] | None
253 The estimated standard deviation of the error used to set `lamda`.
254 yhat_test : Float32[Array, 'ndpost m'] | None
255 The conditional posterior mean at `x_test` for each MCMC iteration.
257 Notes
258 -----
259 This interface imitates the function ``mc_gbart`` from the R package `BART
260 <https://cran.r-project.org/package=BART>`_, but with these differences:
262 - If `x_train` and `x_test` are matrices, they have one predictor per row
263 instead of per column.
264 - If ``usequants=False``, R BART switches to quantiles anyway if there are
265 fewer predictor values than the required number of bins, while bartz
266 always follows the specification.
267 - Some functionality is missing.
268 - The error variance parameter is called `lamda` instead of `lambda`.
269 - There are some additional attributes, and some are missing.
270 - The trees have a maximum depth.
271 - `rm_const` refers to predictors without decision rules instead of
272 predictors that are constant in `x_train`.
273 - If `rm_const=True` and some variables are dropped, the predictors
274 matrix/dataframe passed to `predict` should still include them.
276 References
277 ----------
278 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
279 High-Dimensional Prediction and Variable Selection”. In: Journal of the
280 American Statistical Association 113.522, pp. 626-636.
281 .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
282 “BART: Bayesian additive regression trees”. In: The Annals of Applied
283 Statistics 4.1, pp. 266-298.
284 """
286 _main_trace: mcmcloop.MainTrace 1ab
287 _burnin_trace: mcmcloop.BurninTrace 1ab
288 _mcmc_state: mcmcstep.State 1ab
289 _splits: Real[Array, 'p max_num_splits'] 1ab
290 _x_train_fmt: Any = field(static=True) 1ab
292 ndpost: int = field(static=True) 1ab
293 offset: Float32[Array, ''] 1ab
294 sigest: Float32[Array, ''] | None = None 1ab
295 yhat_test: Float32[Array, 'ndpost m'] | None = None 1ab
297 def __init__( 1ab
298 self,
299 x_train: Real[Array, 'p n'] | DataFrame,
300 y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
301 *,
302 x_test: Real[Array, 'p m'] | DataFrame | None = None,
303 type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
304 sparse: bool = False,
305 theta: FloatLike | None = None,
306 a: FloatLike = 0.5,
307 b: FloatLike = 1.0,
308 rho: FloatLike | None = None,
309 xinfo: Float[Array, 'p n'] | None = None,
310 usequants: bool = False,
311 rm_const: bool | None = True,
312 sigest: FloatLike | None = None,
313 sigdf: FloatLike = 3.0,
314 sigquant: FloatLike = 0.9,
315 k: FloatLike = 2.0,
316 power: FloatLike = 2.0,
317 base: FloatLike = 0.95,
318 lamda: FloatLike | None = None,
319 tau_num: FloatLike | None = None,
320 offset: FloatLike | None = None,
321 w: Float[Array, ' n'] | None = None,
322 ntree: int | None = None,
323 numcut: int = 100,
324 ndpost: int = 1000,
325 nskip: int = 100,
326 keepevery: int | None = None,
327 printevery: int | None = 100,
328 mc_cores: int = 2,
329 seed: int | Key[Array, ''] = 0,
330 maxdepth: int = 6,
331 init_kw: dict | None = None,
332 run_mcmc_kw: dict | None = None,
333 ):
334 # check data and put it in the right format
335 x_train, x_train_fmt = self._process_predictor_input(x_train) 1ab
336 y_train = self._process_response_input(y_train) 1ab
337 self._check_same_length(x_train, y_train) 1ab
338 if w is not None: 1ab
339 w = self._process_response_input(w) 1ab
340 self._check_same_length(x_train, w) 1ab
342 # check data types are correct for continuous/binary regression
343 self._check_type_settings(y_train, type, w) 1ab
344 # from here onwards, the type is determined by y_train.dtype == bool
346 # set defaults that depend on type of regression
347 if ntree is None: 1ab
348 ntree = 50 if y_train.dtype == bool else 200 1ab
349 if keepevery is None: 1ab
350 keepevery = 10 if y_train.dtype == bool else 1 1ab
352 # process sparsity settings
353 theta, a, b, rho = self._process_sparsity_settings( 1ab
354 x_train, sparse, theta, a, b, rho
355 )
357 # process "standardization" settings
358 offset = self._process_offset_settings(y_train, offset) 1ab
359 sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num) 1ab
360 lamda, sigest = self._process_error_variance_settings( 1ab
361 x_train, y_train, sigest, sigdf, sigquant, lamda
362 )
364 # determine splits
365 splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo) 1ab
366 x_train = self._bin_predictors(x_train, splits) 1ab
368 # setup and run mcmc
369 initial_state = self._setup_mcmc( 1ab
370 x_train,
371 y_train,
372 offset,
373 w,
374 max_split,
375 lamda,
376 sigma_mu,
377 sigdf,
378 power,
379 base,
380 maxdepth,
381 ntree,
382 init_kw,
383 rm_const,
384 theta,
385 a,
386 b,
387 rho,
388 )
389 final_state, burnin_trace, main_trace = self._run_mcmc( 1ab
390 initial_state,
391 mc_cores,
392 ndpost,
393 nskip,
394 keepevery,
395 printevery,
396 seed,
397 run_mcmc_kw,
398 sparse,
399 )
401 # set public attributes
402 self.offset = final_state.offset # from the state because of buffer donation 1ab
403 self.ndpost = main_trace.grow_prop_count.size 1ab
404 self.sigest = sigest 1ab
406 # set private attributes
407 self._main_trace = main_trace 1ab
408 self._burnin_trace = burnin_trace 1ab
409 self._mcmc_state = final_state 1ab
410 self._splits = splits 1ab
411 self._x_train_fmt = x_train_fmt 1ab
413 # predict at test points
414 if x_test is not None: 1ab
415 self.yhat_test = self.predict(x_test) 1ab
417 @cached_property 1ab
418 def prob_test(self) -> Float32[Array, 'ndpost m'] | None: 1ab
419 """The posterior probability of y being True at `x_test` for each MCMC iteration."""
420 if self.yhat_test is None or self._mcmc_state.y.dtype != bool: 1ab
421 return None 1ab
422 else:
423 return ndtr(self.yhat_test) 1ab
425 @cached_property 1ab
426 def prob_test_mean(self) -> Float32[Array, ' m'] | None: 1ab
427 """The marginal posterior probability of y being True at `x_test`."""
428 if self.prob_test is None: 1ab
429 return None 1ab
430 else:
431 return self.prob_test.mean(axis=0) 1ab
433 @cached_property 1ab
434 def prob_train(self) -> Float32[Array, 'ndpost n'] | None: 1ab
435 """The posterior probability of y being True at `x_train` for each MCMC iteration."""
436 if self._mcmc_state.y.dtype == bool: 1ab
437 return ndtr(self.yhat_train) 1ab
438 else:
439 return None 1ab
441 @cached_property 1ab
442 def prob_train_mean(self) -> Float32[Array, ' n'] | None: 1ab
443 """The marginal posterior probability of y being True at `x_train`."""
444 if self.prob_train is None: 1ab
445 return None 1ab
446 else:
447 return self.prob_train.mean(axis=0) 1ab
449 @cached_property 1ab
450 def sigma( 1ab
451 self,
452 ) -> (
453 Float32[Array, ' nskip+ndpost']
454 | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
455 | None
456 ):
457 """The standard deviation of the error, including burn-in samples."""
458 if self._burnin_trace.sigma2 is None: 1ab
459 return None 1ab
460 assert self._main_trace.sigma2 is not None 1ab
461 sigma = jnp.sqrt( 1ab
462 jnp.concatenate(
463 [self._burnin_trace.sigma2, self._main_trace.sigma2], axis=1
464 )
465 )
466 sigma = sigma.T 1ab
467 _, mc_cores = sigma.shape 1ab
468 if mc_cores == 1: 1ab
469 sigma = sigma.squeeze(1) 1ab
470 return sigma 1ab
472 @cached_property 1ab
473 def sigma_mean(self) -> Float32[Array, ''] | None: 1ab
474 """The mean of `sigma`, only over the post-burnin samples."""
475 if self.sigma is None: 1ab
476 return None 1ab
477 _, nskip = self._burnin_trace.grow_prop_count.shape 1ab
478 return self.sigma[nskip:, ...].mean() 1ab
480 @cached_property 1ab
481 def varcount(self) -> Int32[Array, 'ndpost p']: 1ab
482 """Histogram of predictor usage for decision rules in the trees."""
483 return self._compute_varcount_multichain_flattened( 1ab
484 self._mcmc_state.forest.max_split.size, self._main_trace
485 )
487 @staticmethod 1ab
488 @partial(jax.vmap, in_axes=(None, 0)) 1ab
489 def _compute_varcount_multichain( 1ab
490 p: int, main_trace: mcmcloop.MainTrace
491 ) -> Int32[Array, 'mc_cores ndpost/mc_cores p']:
492 return mcmcloop.compute_varcount(p, main_trace) 1ab
494 @classmethod 1ab
495 @partial(jax.jit, static_argnums=(0, 1)) 1ab
496 def _compute_varcount_multichain_flattened( 1ab
497 cls, p: int, main_trace: mcmcloop.MainTrace
498 ) -> Int32[Array, 'ndpost p']:
499 return cls._compute_varcount_multichain(p, main_trace).reshape(-1, p) 1ab
501 @cached_property 1ab
502 def varcount_mean(self) -> Float32[Array, ' p']: 1ab
503 """Average of `varcount` across MCMC iterations."""
504 return self.varcount.mean(axis=0) 1ab
506 @cached_property 1ab
507 def varprob(self) -> Float32[Array, 'ndpost p']: 1ab
508 """Posterior samples of the probability of choosing each predictor for a decision rule."""
509 max_split = self._mcmc_state.forest.max_split 1ab
510 p = max_split.size 1ab
511 varprob = self._main_trace.varprob 1ab
512 if varprob is None: 1ab
513 peff = jnp.count_nonzero(max_split) 1ab
514 varprob = jnp.where(max_split, 1 / peff, 0) 1ab
515 varprob = jnp.broadcast_to(varprob, (self.ndpost, p)) 1ab
516 else:
517 varprob = varprob.reshape(-1, p) 1ab
518 return varprob 1ab
520 @cached_property 1ab
521 def varprob_mean(self) -> Float32[Array, ' p']: 1ab
522 """The marginal posterior probability of each predictor being chosen for a decision rule."""
523 return self.varprob.mean(axis=0) 1ab
525 @cached_property 1ab
526 def yhat_test_mean(self) -> Float32[Array, ' m'] | None: 1ab
527 """The marginal posterior mean at `x_test`.
529 Not defined for binary regression because it is error-prone; typically
530 the right quantity to consider is `prob_test_mean`.
531 """
532 if self.yhat_test is None or self._mcmc_state.y.dtype == bool: 1ab
533 return None 1ab
534 else:
535 return self.yhat_test.mean(axis=0) 1ab
537 @cached_property 1ab
538 def yhat_train(self) -> Float32[Array, 'ndpost n']: 1ab
539 """The conditional posterior mean at `x_train` for each MCMC iteration."""
540 x_train = self._mcmc_state.X 1ab
541 return self._predict(x_train) 1ab
543 @cached_property 1ab
544 def yhat_train_mean(self) -> Float32[Array, ' n'] | None: 1ab
545 """The marginal posterior mean at `x_train`.
547 Not defined for binary regression because it is error-prone; typically
548 the right quantity to consider is `prob_train_mean`.
549 """
550 if self._mcmc_state.y.dtype == bool: 1ab
551 return None 1ab
552 else:
553 return self.yhat_train.mean(axis=0) 1ab
555 def predict( 1ab
556 self, x_test: Real[Array, 'p m'] | DataFrame
557 ) -> Float32[Array, 'ndpost m']:
558 """
559 Compute the posterior mean at `x_test` for each MCMC iteration.
561 Parameters
562 ----------
563 x_test
564 The test predictors.
566 Returns
567 -------
568 The conditional posterior mean at `x_test` for each MCMC iteration.
570 Raises
571 ------
572 ValueError
573 If `x_test` has a different format than `x_train`.
574 """
575 x_test, x_test_fmt = self._process_predictor_input(x_test) 1ab
576 if x_test_fmt != self._x_train_fmt: 1ab
577 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 1ab
578 raise ValueError(msg) 1ab
579 x_test = self._bin_predictors(x_test, self._splits) 1ab
580 return self._predict(x_test) 1ab
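    # Usage sketch (variable names are illustrative):
    #
    #     bart = mc_gbart(x_train, y_train, seed=0)
    #     yhat = bart.predict(x_test)  # shape (ndpost, m)
    #
    # `x_test` must be in the same format as `x_train`: the same dataframe
    # columns, or a matrix with the same number of predictor rows.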
582 @staticmethod 1ab
583 def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]: 1ab
584 if hasattr(x, 'columns'): 1ab
585 fmt = dict(kind='dataframe', columns=x.columns) 1ab
586 x = x.to_numpy().T 1ab
587 else:
588 fmt = dict(kind='array', num_covar=x.shape[0]) 1ab
589 x = jnp.asarray(x) 1ab
590 assert x.ndim == 2 1ab
591 return x, fmt 1ab
593 @staticmethod 1ab
594 def _process_response_input(y) -> Shaped[Array, ' n']: 1ab
595 if hasattr(y, 'to_numpy'): 1ab
596 y = y.to_numpy() 1ab
597 y = jnp.asarray(y) 1ab
598 assert y.ndim == 1 1ab
599 return y 1ab
601 @staticmethod 1ab
602 def _check_same_length(x1, x2): 1ab
603 get_length = lambda x: x.shape[-1] 1ab
604 assert get_length(x1) == get_length(x2) 1ab
606 @staticmethod 1ab
607 def _process_error_variance_settings( 1ab
608 x_train, y_train, sigest, sigdf, sigquant, lamda
609 ) -> tuple[Float32[Array, ''] | None, ...]:
610 if y_train.dtype == bool: 1ab
611 if sigest is not None: 611 ↛ 612 1ab
612 msg = 'Let `sigest=None` for binary regression'
613 raise ValueError(msg)
614 if lamda is not None: 614 ↛ 615 1ab
615 msg = 'Let `lamda=None` for binary regression'
616 raise ValueError(msg)
617 return None, None 1ab
618 elif lamda is not None: 618 ↛ 619 1ab
619 if sigest is not None:
620 msg = 'Let `sigest=None` if `lamda` is specified'
621 raise ValueError(msg)
622 return lamda, None
623 else:
624 if sigest is not None: 624 ↛ 625 1ab
625 sigest2 = jnp.square(sigest)
626 elif y_train.size < 2: 1ab
627 sigest2 = 1 1ab
628 elif y_train.size <= x_train.shape[0]: 1ab
629 sigest2 = jnp.var(y_train) 1ab
630 else:
631 x_centered = x_train.T - x_train.mean(axis=1) 1ab
632 y_centered = y_train - y_train.mean() 1ab
633 # centering is equivalent to adding an intercept column
634 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered) 1ab
635 chisq = chisq.squeeze(0) 1ab
636 dof = len(y_train) - rank 1ab
637 sigest2 = chisq / dof 1ab
638 alpha = sigdf / 2 1ab
639 invchi2 = invgamma.ppf(sigquant, alpha) / 2 1ab
640 invchi2rid = invchi2 * sigdf 1ab
641 return sigest2 / invchi2rid, jnp.sqrt(sigest2) 1ab
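    # A sketch of the computation above: the error variance prior is
    # sigma2 ~ InvGamma(sigdf/2, lamda * sigdf/2), whose harmonic mean is
    # `lamda`; `lamda` is chosen such that P(sigma2 <= sigest**2) = sigquant,
    # i.e. lamda = sigest**2 / (sigdf/2 * invgamma.ppf(sigquant, sigdf/2)).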
643 @staticmethod 1ab
644 def _check_type_settings(y_train, type, w): # noqa: A002 1ab
645 match type: 1ab
646 case 'wbart': 1ab
647 if y_train.dtype != jnp.float32: 647 ↛ 648 1ab
648 msg = (
649 'Continuous regression requires y_train.dtype=float32,'
650 f' got {y_train.dtype=} instead.'
651 )
652 raise TypeError(msg) 1a
653 case 'pbart': 653 ↛ 663 1ab
654 if w is not None: 654 ↛ 655 1ab
655 msg = 'Binary regression does not support weights, set `w=None`'
656 raise ValueError(msg)
657 if y_train.dtype != bool: 657 ↛ 658 1ab
658 msg = (
659 'Binary regression requires y_train.dtype=bool,'
660 f' got {y_train.dtype=} instead.'
661 )
662 raise TypeError(msg) 1a
663 case _:
664 msg = f'Invalid {type=}'
665 raise ValueError(msg)
667 @staticmethod 1ab
668 def _process_sparsity_settings( 1ab
669 x_train: Real[Array, 'p n'],
670 sparse: bool,
671 theta: FloatLike | None,
672 a: FloatLike,
673 b: FloatLike,
674 rho: FloatLike | None,
675 ) -> (
676 tuple[None, None, None, None]
677 | tuple[FloatLike, None, None, None]
678 | tuple[None, FloatLike, FloatLike, FloatLike]
679 ):
680 if not sparse: 1ab
681 return None, None, None, None 1ab
682 elif theta is not None: 1ab
683 return theta, None, None, None 1ab
684 else:
685 if rho is None: 685 ↛ 688 1ab
686 p, _ = x_train.shape 1ab
687 rho = float(p) 1ab
688 return None, a, b, rho 1ab
690 @staticmethod 1ab
691 def _process_offset_settings( 1ab
692 y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
693 offset: float | Float32[Any, ''] | None,
694 ) -> Float32[Array, '']:
695 if offset is not None: 695 ↛ 696 1ab
696 return jnp.asarray(offset)
697 elif y_train.size < 1: 1ab
698 return jnp.array(0.0) 1ab
699 else:
700 mean = y_train.mean() 1ab
702 if y_train.dtype == bool: 1ab
703 bound = 1 / (1 + y_train.size) 1ab
704 mean = jnp.clip(mean, bound, 1 - bound) 1ab
705 return ndtri(mean) 1ab
706 else:
707 return mean 1ab
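    # Worked example of the clipping above (a sketch): with n=100 binary
    # observations that are all False, the mean is clipped to 1/(n+1) = 1/101,
    # so offset = ndtri(1/101) ≈ -2.33.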
709 @staticmethod 1ab
710 def _process_leaf_sdev_settings( 1ab
711 y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
712 k: float,
713 ntree: int,
714 tau_num: FloatLike | None,
715 ):
716 if tau_num is None: 716 ↛ 724 1ab
717 if y_train.dtype == bool: 1ab
718 tau_num = 3.0 1ab
719 elif y_train.size < 2: 1ab
720 tau_num = 1.0 1ab
721 else:
722 tau_num = (y_train.max() - y_train.min()) / 2 1ab
724 return tau_num / (k * math.sqrt(ntree)) 1ab
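    # Worked example (a sketch): with the continuous-regression defaults k=2,
    # ntree=200 and tau_num = (max(y) - min(y)) / 2, the prior sd of the leaves
    # is sigma_mu = (max(y) - min(y)) / (4 * sqrt(200)) ≈ range(y) / 56.6.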
726 @staticmethod 1ab
727 def _determine_splits( 1ab
728 x_train: Real[Array, 'p n'],
729 usequants: bool,
730 numcut: int,
731 xinfo: Float[Array, 'p n'] | None,
732 ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
733 if xinfo is not None: 1ab
734 if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]: 1ab
735 msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)' 1ab
736 raise ValueError(msg) 1ab
737 return prepcovars.parse_xinfo(xinfo) 1ab
738 elif usequants: 1ab
739 return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1) 1ab
740 else:
741 return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1) 1ab
743 @staticmethod 1ab
744 def _bin_predictors(x, splits) -> UInt[Array, 'p n']: 1ab
745 return prepcovars.bin_predictors(x, splits) 1ab
747 @staticmethod 1ab
748 def _setup_mcmc( 1ab
749 x_train: Real[Array, 'p n'],
750 y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
751 offset: Float32[Array, ''],
752 w: Float[Array, ' n'] | None,
753 max_split: UInt[Array, ' p'],
754 lamda: Float32[Array, ''] | None,
755 sigma_mu: FloatLike,
756 sigdf: FloatLike,
757 power: FloatLike,
758 base: FloatLike,
759 maxdepth: int,
760 ntree: int,
761 init_kw: dict[str, Any] | None,
762 rm_const: bool | None,
763 theta: FloatLike | None,
764 a: FloatLike | None,
765 b: FloatLike | None,
766 rho: FloatLike | None,
767 ):
768 depth = jnp.arange(maxdepth - 1) 1ab
769 p_nonterminal = base / (1 + depth).astype(float) ** power 1ab
771 if y_train.dtype == bool: 1ab
772 sigma2_alpha = None 1ab
773 sigma2_beta = None 1ab
774 else:
775 sigma2_alpha = sigdf / 2 1ab
776 sigma2_beta = lamda * sigma2_alpha 1ab
778 kw = dict( 1ab
779 X=x_train,
780 # copy y_train because it's going to be donated in the mcmc loop
781 y=jnp.array(y_train),
782 offset=offset,
783 error_scale=w,
784 max_split=max_split,
785 num_trees=ntree,
786 p_nonterminal=p_nonterminal,
787 sigma_mu2=jnp.square(sigma_mu),
788 sigma2_alpha=sigma2_alpha,
789 sigma2_beta=sigma2_beta,
790 min_points_per_decision_node=10,
791 min_points_per_leaf=5,
792 theta=theta,
793 a=a,
794 b=b,
795 rho=rho,
796 )
798 if rm_const is None: 1ab
799 kw.update(filter_splitless_vars=False) 1ab
800 elif rm_const: 800 ↛ 803 1ab
801 kw.update(filter_splitless_vars=True) 1ab
802 else:
803 n_empty = jnp.count_nonzero(max_split == 0)
804 if n_empty:
805 msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
806 raise ValueError(msg)
807 kw.update(filter_splitless_vars=False)
809 if init_kw is not None: 1ab
810 kw.update(init_kw) 1ab
812 return mcmcstep.init(**kw) 1ab
814 @classmethod 1ab
815 def _run_mcmc( 1ab
816 cls,
817 mcmc_state: mcmcstep.State,
818 mc_cores: int,
819 ndpost: int,
820 nskip: int,
821 keepevery: int,
822 printevery: int | None,
823 seed: int | Integer[Array, ''] | Key[Array, ''],
824 run_mcmc_kw: dict | None,
825 sparse: bool,
826 ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]:
827 # prepare random generator seed
828 if isinstance(seed, jax.Array) and jnp.issubdtype( 1ab
829 seed.dtype, jax.dtypes.prng_key
830 ):
831 key = seed 1ab
832 else:
833 key = jax.random.key(seed) 1ab
835 # round up ndpost
836 ndpost = mc_cores * (ndpost // mc_cores + bool(ndpost % mc_cores)) 1ab
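        # e.g. ndpost=1000 with mc_cores=3 is rounded up to 1002, i.e. 334
        # post-burn-in samples per chain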
838 # prepare arguments
839 kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery) 1ab
840 kw.update( 1ab
841 mcmcloop.make_default_callback(
842 dot_every=None if printevery is None or printevery == 1 else 1,
843 report_every=printevery,
844 sparse_on_at=nskip // 2 if sparse else None,
845 )
846 )
847 if run_mcmc_kw is not None: 1ab
848 kw.update(run_mcmc_kw) 1ab
850 if mc_cores == 1: 1ab
851 return cls._single_run_mcmc(key, mcmc_state, ndpost, **kw) 1ab
852 else:
853 keys = jax.random.split(key, mc_cores) 1ab
854 return cls._vmapped_run_mcmc(keys, mcmc_state, ndpost // mc_cores, **kw) 1ab
856 @classmethod 1ab
857 def _single_run_mcmc( 1ab
858 cls, key: Key[Array, ''], bart: mcmcstep.State, *args, **kwargs
859 ):
860 out = mcmcloop.run_mcmc(key, bart, *args, **kwargs) 1ab
861 axes = cls._vmap_axes_for_state(bart) 1ab
862 return jax.vmap(lambda x: x, in_axes=None, out_axes=(axes, 0, 0), axis_size=1)( 1ab
863 out
864 )
866 @classmethod 1ab
867 def _vmapped_run_mcmc( 1ab
868 cls, keys: Key[Array, ' mc_cores'], bart: mcmcstep.State, *args, **kwargs
869 ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]:
870 bart_axes = cls._vmap_axes_for_state(bart) 1ab
872 barts = jax.vmap( 1ab
873 lambda x: x, in_axes=None, out_axes=bart_axes, axis_size=keys.size
874 )(bart)
876 @partial(jax.vmap, in_axes=(0, bart_axes), out_axes=(bart_axes, 0, 0)) 1ab
877 def _partial_vmapped_run_mcmc(key, bart): 1ab
878 return mcmcloop.run_mcmc(key, bart, *args, **kwargs) 1ab
880 return _partial_vmapped_run_mcmc(keys, barts) 1ab
882 @staticmethod 1ab
883 def _vmap_axes_for_state(state: mcmcstep.State) -> mcmcstep.State: 1ab
884 def choose_vmap_index(path, _) -> Literal[0, None]: 1ab
885 no_vmap_attrs = ( 1ab
886 '.X',
887 '.y',
888 '.offset',
889 '.prec_scale',
890 '.sigma2_alpha',
891 '.sigma2_beta',
892 '.forest.max_split',
893 '.forest.blocked_vars',
894 '.forest.p_nonterminal',
895 '.forest.p_propose_grow',
896 '.forest.min_points_per_decision_node',
897 '.forest.min_points_per_leaf',
898 '.forest.sigma_mu2',
899 '.forest.a',
900 '.forest.b',
901 '.forest.rho',
902 )
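            # the attributes listed above are constant across chains and are
            # broadcast (axis None); every other leaf of the state gets a
            # leading chain axis (axis 0)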
903 str_path = ''.join(map(str, path)) 1ab
904 if str_path in no_vmap_attrs: 1ab
905 return None 1ab
906 else:
907 return 0 1ab
909 return map_with_path(choose_vmap_index, state) 1ab
911 def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']: 1ab
912 return self._evaluate_chains_flattened(self._main_trace, x) 1ab
914 @classmethod 1ab
915 @partial(jax.jit, static_argnums=(0,)) 1ab
916 def _evaluate_chains_flattened( 1ab
917 cls, trace: mcmcloop.MainTrace, x: UInt[Array, 'p m']
918 ) -> Float32[Array, 'ndpost m']:
919 out = cls._evaluate_chains(trace, x) 1ab
920 mc_cores, ndpost_per_chain, m = out.shape 1ab
921 return out.reshape(mc_cores * ndpost_per_chain, m) 1ab
923 @staticmethod 1ab
924 @partial(jax.vmap, in_axes=(0, None)) 1ab
925 def _evaluate_chains( 1ab
926 trace: mcmcloop.MainTrace, x: UInt[Array, 'p m']
927 ) -> Float32[Array, 'mc_cores ndpost/mc_cores m']:
928 return mcmcloop.evaluate_trace(trace, x) 1ab
931class gbart(mc_gbart): 1ab
932 """Subclass of `mc_gbart` that forces `mc_cores=1`."""
934 def __init__(self, *args, **kwargs): 1ab
935 if 'mc_cores' in kwargs: 1ab
936 msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'" 1ab
937 raise TypeError(msg) 1ab
938 kwargs.update(mc_cores=1) 1ab
939 super().__init__(*args, **kwargs) 1ab
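if __name__ == '__main__':
    # A minimal usage sketch (not part of the library): fit a continuous
    # regression on synthetic data and predict at held-out points. The data,
    # sizes, and hyperparameters below are illustrative only.
    import numpy as np

    rng = np.random.default_rng(0)
    p, n, m = 5, 200, 50
    x_train = rng.standard_normal((p, n))  # one predictor per row
    y_train = (x_train[0] + 0.1 * rng.standard_normal(n)).astype(np.float32)
    x_test = rng.standard_normal((p, m))

    bart = mc_gbart(
        x_train,
        y_train,
        x_test=x_test,
        ndpost=200,
        nskip=100,
        mc_cores=2,
        seed=0,
    )
    print(bart.yhat_test.shape)  # (ndpost, m): posterior mean draws at x_test
    print(bart.sigma_mean)  # posterior mean of the error standard deviation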