Coverage for src/bartz/BART.py: 88% of 258 statements
coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
# bartz/src/bartz/BART.py
#
# Copyright (c) 2024-2025, Giacomo Petrillo
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement a class `gbart` that mimics the R BART package."""

import math
from collections.abc import Sequence
from functools import cached_property
from typing import Any, Literal, Protocol

import jax
import jax.numpy as jnp
from equinox import Module, field
from jax.scipy.special import ndtr
from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, Shaped, UInt
from numpy import ndarray

from bartz import mcmcloop, mcmcstep, prepcovars
from bartz.jaxext.scipy.special import ndtri
from bartz.jaxext.scipy.stats import invgamma

FloatLike = float | Float[Any, '']


class DataFrame(Protocol):
    """DataFrame duck-type for `gbart`.

    Attributes
    ----------
    columns : Sequence[str]
        The names of the columns.
    """

    columns: Sequence[str]

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...


class Series(Protocol):
    """Series duck-type for `gbart`.

    Attributes
    ----------
    name : str | None
        The name of the series.
    """

    name: str | None

    def to_numpy(self) -> ndarray:
        """Convert the series to a 1d numpy array."""
        ...


class gbart(Module):
    """
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any. If
        `None`, no check is performed, and the output of the MCMC may not make
        sense if there are predictors without cutpoints. The option `None` is
        provided only to allow jax tracing.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse chi-squared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train)
        - min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, evenly spaced between the minimum and maximum observed
        values (both excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the smaller of ``numcut + 1`` and the number
        of unique values appearing in `x_train`.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in.
    nskip
        The number of initial MCMC samples to discard as burn-in.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between
        consecutive log lines. Set to `None` to disable logging.

        `printevery` has a few unexpected side effects. On CPU, interrupting
        with ^C halts the MCMC only at the next log line. And the total number
        of iterations is a multiple of `printevery`, so if ``nskip + keepevery
        * ndpost`` is not a multiple of `printevery`, some of the last
        iterations will not be saved.
    seed
        The seed for the random number generator.
    maxdepth
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw
        Additional arguments passed to `bartz.mcmcstep.init`.
    run_mcmc_kw
        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.

    Attributes
    ----------
    offset : Float32[Array, '']
        The prior mean of the latent mean function.
    sigest : Float32[Array, ''] | None
        The estimated standard deviation of the error used to set `lamda`.
    sigma : Float32[Array, 'nskip+ndpost'] | None
        The standard deviation of the error, including burn-in samples.
    yhat_test : Float32[Array, 'ndpost m'] | None
        The conditional posterior mean at `x_test` for each MCMC iteration.

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - Some functionality is missing (e.g., variable selection).
    - There are some additional attributes, and some are missing.
    - The trees have a maximum depth.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
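
    Examples
    --------
    A minimal sketch of the intended usage, on made-up data; names and sizes
    here are illustrative only:

    >>> from jax import random
    >>> from bartz.BART import gbart
    >>> key1, key2 = random.split(random.key(0))
    >>> X = random.uniform(key1, (2, 100))  # p=2 predictors, n=100 points
    >>> y = X[0] - X[1] + 0.1 * random.normal(key2, (100,))
    >>> fit = gbart(X, y, ndpost=100, nskip=50, printevery=None)
    >>> yhat = fit.yhat_train_mean  # marginal posterior mean at x_train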
233 """

    _main_trace: mcmcloop.MainTrace
    _mcmc_state: mcmcstep.State
    _splits: Real[Array, 'p max_num_splits']
    _x_train_fmt: Any = field(static=True)

    ndpost: int = field(static=True)
    offset: Float32[Array, '']
    sigma: Float32[Array, ' nskip+ndpost'] | None = None
    sigest: Float32[Array, ''] | None = None
    yhat_test: Float32[Array, 'ndpost m'] | None = None

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool | None = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: dict | None = None,
        run_mcmc_kw: dict | None = None,
    ):
        # check data and put it in the right format
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train, _ = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            w, _ = self._process_response_input(w)
            self._check_same_length(x_train, w)

        # check data types are correct for continuous/binary regression
        self._check_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool

        # set defaults that depend on type of regression
        if ntree is None:
            ntree = 50 if y_train.dtype == bool else 200
        if keepevery is None:
            keepevery = 10 if y_train.dtype == bool else 1

        # process "standardization" settings
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        # determine splits
        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
        x_train = self._bin_predictors(x_train, splits)

        # setup and run mcmc
        initial_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            ntree,
            init_kw,
            rm_const,
        )
        final_state, burnin_trace, main_trace = self._run_mcmc(
            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        # set public attributes
        self.offset = final_state.offset  # from the state because of buffer donation
        self.ndpost = ndpost
        self.sigest = sigest
        self.sigma = self._extract_sigma(burnin_trace, main_trace)

        # set private attributes
        self._main_trace = main_trace
        self._mcmc_state = final_state
        self._splits = splits
        self._x_train_fmt = x_train_fmt

        # predict at test points
        if x_test is not None:
            self.yhat_test = self.predict(x_test)

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
            return None
        else:
            return ndtr(self.yhat_test)

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        if self.prob_test is None:
            return None
        else:
            return self.prob_test.mean(axis=0)

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        if self._mcmc_state.y.dtype == bool:
            return ndtr(self.yhat_train)
        else:
            return None

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        if self.prob_train is None:
            return None
        else:
            return self.prob_train.mean(axis=0)

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        if self.sigma is None:
            return None
        else:
            return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return mcmcloop.compute_varcount(
            self._mcmc_state.forest.max_split.size, self._main_trace
        )

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self.varcount.mean(axis=0)

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined for binary regression because it is error-prone; typically,
        the right quantity to consider is `prob_test_mean`.
        """
        if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
            return None
        else:
            return self.yhat_test.mean(axis=0)

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        x_train = self._mcmc_state.X
        return self._predict(x_train)

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined for binary regression because it is error-prone; typically,
        the right quantity to consider is `prob_train_mean`.
        """
        if self._mcmc_state.y.dtype == bool:
            return None
        else:
            return self.yhat_train.mean(axis=0)

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`.
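
        Examples
        --------
        A sketch only: `bart` stands for a fitted `gbart` object and `x_test`
        for new predictors in the same format as the training predictors.

        >>> yhat_test = bart.predict(x_test)  # shape (ndpost, m)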
444 """
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        if x_test_fmt != self._x_train_fmt:
            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
            raise ValueError(msg)
        x_test = self._bin_predictors(x_test, self._splits)
        return self._predict(x_test)

    @staticmethod
    def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(y) -> tuple[Shaped[Array, ' n'], Any]:
        if hasattr(y, 'to_numpy'):
            fmt = dict(kind='series', name=y.name)
            y = y.to_numpy()
        else:
            fmt = dict(kind='array')
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y, fmt

    @staticmethod
    def _check_same_length(x1, x2):
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @staticmethod
    def _process_error_variance_settings(
        x_train, y_train, sigest, sigdf, sigquant, lamda
    ) -> tuple[Float32[Array, ''] | None, ...]:
        if y_train.dtype == bool:
            if sigest is not None:
                msg = 'Let `sigest=None` for binary regression'
                raise ValueError(msg)
            if lamda is not None:
                msg = 'Let `lamda=None` for binary regression'
                raise ValueError(msg)
            return None, None
        elif lamda is not None:
            if sigest is not None:
                msg = 'Let `sigest=None` if `lamda` is specified'
                raise ValueError(msg)
            return lamda, None
        else:
            if sigest is not None:
                sigest2 = jnp.square(sigest)
            elif y_train.size < 2:
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                sigest2 = jnp.var(y_train)
            else:
                x_centered = x_train.T - x_train.mean(axis=1)
                y_centered = y_train - y_train.mean()
                # centering is equivalent to adding an intercept column
                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                chisq = chisq.squeeze(0)
                dof = len(y_train) - rank
                sigest2 = chisq / dof
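            # match the prior to sigest2: under the scaled inverse chi-squared
            # prior, sigma2 = lamda * sigdf / chi2(sigdf), so taking
            # lamda = sigest2 / quantile(sigdf / chi2(sigdf), sigquant) puts
            # prior probability sigquant below sigest2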
            alpha = sigdf / 2
            invchi2 = invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _check_type_settings(y_train, type, w):  # noqa: A002
        match type:
            case 'wbart':
                if y_train.dtype != jnp.float32:
                    msg = (
                        'Continuous regression requires y_train.dtype=float32,'
                        f' got {y_train.dtype=} instead.'
                    )
                    raise TypeError(msg)
            case 'pbart':
                if w is not None:
                    msg = 'Binary regression does not support weights, set `w=None`'
                    raise ValueError(msg)
                if y_train.dtype != bool:
                    msg = (
                        'Binary regression requires y_train.dtype=bool,'
                        f' got {y_train.dtype=} instead.'
                    )
                    raise TypeError(msg)
            case _:
                msg = f'Invalid {type=}'
                raise ValueError(msg)

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: float | Float32[Any, ''] | None,
    ) -> Float32[Array, '']:
        if offset is not None:
            return jnp.asarray(offset)
        elif y_train.size < 1:
            return jnp.array(0.0)
        else:
            mean = y_train.mean()

            if y_train.dtype == bool:
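                # probit link: clip the mean into [1/(n+1), n/(n+1)] so that
                # ndtri stays finite when y_train is all False or all True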
                bound = 1 / (1 + y_train.size)
                mean = jnp.clip(mean, bound, 1 - bound)
                return ndtri(mean)
            else:
                return mean

    @staticmethod
    def _process_leaf_sdev_settings(
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        k: float,
        ntree: int,
        tau_num: FloatLike | None,
    ):
        if tau_num is None:
            if y_train.dtype == bool:
                tau_num = 3.0
            elif y_train.size < 2:
                tau_num = 1.0
            else:
                tau_num = (y_train.max() - y_train.min()) / 2
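
        # scale the per-leaf prior sd so that the sum of ntree i.i.d. leaves
        # has prior sd tau_num / k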
        return tau_num / (k * math.sqrt(ntree))

    @staticmethod
    def _determine_splits(
        x_train: Real[Array, 'p n'],
        usequants: bool,
        numcut: int,
        xinfo: Float[Array, 'p n'] | None,
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        if xinfo is not None:
            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
                raise ValueError(msg)
            return prepcovars.parse_xinfo(xinfo)
        elif usequants:
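            # numcut cutpoints delimit at most numcut + 1 bins per predictor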
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: Float32[Array, ''],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        lamda: Float32[Array, ''] | None,
        sigma_mu: FloatLike,
        sigdf: FloatLike,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        ntree: int,
        init_kw: dict[str, Any] | None,
        rm_const: bool | None,
    ):
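        # prior probability that a node at 0-based depth d is non-terminal:
        # base / (1 + d) ** power; maxdepth is 1-based, so only depths
        # 0 to maxdepth - 2 can have children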
        depth = jnp.arange(maxdepth - 1)
        p_nonterminal = base / (1 + depth).astype(float) ** power

        if y_train.dtype == bool:
            sigma2_alpha = None
            sigma2_beta = None
        else:
            sigma2_alpha = sigdf / 2
            sigma2_beta = lamda * sigma2_alpha

        kw = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma_mu2=jnp.square(sigma_mu),
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_decision_node=10,
            min_points_per_leaf=5,
        )

        if rm_const is None:
            kw.update(filter_splitless_vars=False)
        elif rm_const:
            kw.update(filter_splitless_vars=True)
        else:
            n_empty = jnp.count_nonzero(max_split == 0)
            if n_empty:
                msg = f'There are {n_empty} predictors without decision rules'
                raise ValueError(msg)
            kw.update(filter_splitless_vars=False)

        if init_kw is not None:
            kw.update(init_kw)

        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
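        # accept both an integer seed and a pre-made jax random key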
        if isinstance(seed, jax.Array) and jnp.issubdtype(
            seed.dtype, jax.dtypes.prng_key
        ):
            key = seed.copy()
            # copy because the inner loop in run_mcmc will donate the buffer
        else:
            key = jax.random.key(seed)

        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
        if printevery is not None:
            kw.update(
                mcmcloop.make_print_callback(None if printevery == 1 else 1, printevery)
            )
        if run_mcmc_kw is not None:
            kw.update(run_mcmc_kw)

        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

    @staticmethod
    def _extract_sigma(
        burnin_trace: mcmcloop.BurninTrace, main_trace: mcmcloop.MainTrace
    ) -> Float32[Array, ' trace_length'] | None:
        if burnin_trace.sigma2 is None:
            return None
        else:
            assert main_trace.sigma2 is not None
            return jnp.sqrt(jnp.concatenate([burnin_trace.sigma2, main_trace.sigma2]))

    def _predict(self, x):
        return mcmcloop.evaluate_trace(self._main_trace, x)