Coverage for src/bartz/BART.py: 78%
236 statements
coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
# bartz/src/bartz/BART.py
#
# Copyright (c) 2024-2025, Giacomo Petrillo
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
25"""Implement a user interface that mimics the R BART package."""
27import functools 1ab
28import math 1ab
29from typing import Any, Literal 1ab
31import jax 1ab
32import jax.numpy as jnp 1ab
33from jax.scipy.special import ndtri 1ab
34from jaxtyping import Array, Bool, Float, Float32 1ab
36from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars 1ab
38FloatLike = float | Float[Any, ''] 1ab
class gbart:
    """
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with MCMC.

    Parameters
    ----------
    x_train : array (p, n) or DataFrame
        The training predictors.
    y_train : array (n,) or Series
        The training responses.
    x_test : array (p, m) or DataFrame, optional
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart'
        for binary regression with a probit link.
    usequants : bool, default False
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors.
    sigest : float, optional
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf : int, default 3
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant : float, default 0.9
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k : float, default 2
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale
        is set to 1.
    power : float, default 2
    base : float, default 0.95
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of
        x is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train)
        - min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and to 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0.
    w : array (n,), optional
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for
        all datapoints. Note: `w` is ignored in the automatic determination
        of `sigest`, so either the weights should be O(1), or `sigest` should
        be specified by the user.
    ntree : int, default 200
        The number of trees used to represent the latent mean function.
    numcut : int, default 255
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, ranging between the minimum and maximum observed
        values (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type.
    ndpost : int, default 1000
        The number of MCMC samples to save, after burn-in.
    nskip : int, default 100
        The number of initial MCMC samples to discard as burn-in.
    keepevery : int, default 1
        The thinning factor for the MCMC samples, after burn-in.
    printevery : int or None, default 100
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging.

        `printevery` has a few unexpected side effects. On CPU, interrupting
        with ^C halts the MCMC only at the next log line. And the total
        number of iterations is a multiple of `printevery`, so if ``nskip +
        keepevery * ndpost`` is not a multiple of `printevery`, some of the
        last iterations will not be saved.
    seed : int or jax random key, default 0
        The seed for the random number generator.
    maxdepth : int, default 6
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw : dict
        Additional arguments passed to `mcmcstep.init`.
    run_mcmc_kw : dict
        Additional arguments passed to `mcmcloop.run_mcmc`.

    Attributes
    ----------
    yhat_train : array (ndpost, n)
        The conditional posterior mean at `x_train` for each MCMC iteration.
    yhat_train_mean : array (n,)
        The marginal posterior mean at `x_train`.
    yhat_test : array (ndpost, m)
        The conditional posterior mean at `x_test` for each MCMC iteration.
    yhat_test_mean : array (m,)
        The marginal posterior mean at `x_test`.
    sigma : array (ndpost,)
        The standard deviation of the error.
    first_sigma : array (nskip,)
        The standard deviation of the error in the burn-in phase.
    offset : float
        The prior mean of the latent mean function.
    sigest : float or None
        The estimated standard deviation of the error used to set `lamda`.

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If `type` is not specified, it is determined solely based on the data
      type of `y_train`, and not on whether it contains only two unique
      values.
    - If ``usequants=False``, R BART switches to quantiles anyway if there
      are fewer predictor values than the required number of bins, while
      bartz always follows the specification.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - `rm_const` is always `False`.
    - The default `numcut` is 255 instead of 100.
    - A lot of functionality is missing (e.g., variable selection).
    - There are some additional attributes, and some are missing.
    - The trees have a maximum depth.
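
    Examples
    --------
    A minimal usage sketch on synthetic data (the data, sizes, and MCMC
    settings below are illustrative, not taken from the original
    documentation; the import path assumes the package layout shown in the
    file header):

    >>> import numpy as np
    >>> from bartz.BART import gbart
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((2, 100))  # (p, n): one predictor per row
    >>> y = (X[0] - X[1] + 0.1 * rng.standard_normal(100)).astype(np.float32)
    >>> fit = gbart(X, y, ndpost=50, nskip=50, printevery=None)
    >>> yhat = fit.yhat_train  # (ndpost, n) posterior draws of the mean
    >>> point = fit.yhat_train_mean  # marginal posterior mean, shape (n,)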
183 """
185 def __init__( 1ab
186 self,
187 x_train,
188 y_train,
189 *,
190 x_test=None,
191 type: Literal['wbart', 'pbart'] = 'wbart',
192 usequants=False,
193 sigest=None,
194 sigdf=3,
195 sigquant=0.9,
196 k=2,
197 power=2,
198 base=0.95,
199 lamda: FloatLike | None = None,
200 tau_num: FloatLike | None = None,
201 offset: FloatLike | None = None,
202 w=None,
203 ntree=200,
204 numcut=255,
205 ndpost=1000,
206 nskip=100,
207 keepevery=1,
208 printevery=100,
209 seed=0,
210 maxdepth=6,
211 init_kw=None,
212 run_mcmc_kw=None,
213 ):
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train, _ = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            w, _ = self._process_response_input(w)
            self._check_same_length(x_train, w)

        y_train = self._process_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        splits, max_split = self._determine_splits(x_train, usequants, numcut)
        x_train = self._bin_predictors(x_train, splits)

        mcmc_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            ntree,
            init_kw,
        )
        final_state, burnin_trace, main_trace = self._run_mcmc(
            mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        sigma = self._extract_sigma(main_trace)
        first_sigma = self._extract_sigma(burnin_trace)

        self.offset = final_state.offset  # from the state because of buffer donation
        self.sigest = sigest
        self.sigma = sigma
        self.first_sigma = first_sigma

        self._x_train_fmt = x_train_fmt
        self._splits = splits
        self._main_trace = main_trace
        self._mcmc_state = final_state

        if x_test is not None:
            yhat_test = self.predict(x_test)
            self.yhat_test = yhat_test
            self.yhat_test_mean = yhat_test.mean(axis=0)

    @functools.cached_property
    def yhat_train(self):
        x_train = self._mcmc_state.X
        return self._predict(self._main_trace, x_train)

    @functools.cached_property
    def yhat_train_mean(self):
        return self.yhat_train.mean(axis=0)

    def predict(self, x_test):
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test : array (p, m) or DataFrame
            The test predictors.

        Returns
        -------
        yhat_test : array (ndpost, m)
            The conditional posterior mean at `x_test` for each MCMC iteration.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`.
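
        Examples
        --------
        A usage sketch; ``fit`` and ``x_new`` are placeholder names for an
        already-fitted `gbart` instance and for new predictors with the same
        number of rows as `x_train`, not objects defined in this module:

        >>> yhat = fit.predict(x_new)  # doctest: +SKIP
        >>> yhat_point = yhat.mean(axis=0)  # doctest: +SKIP

        Averaging over the first axis gives the marginal posterior mean at
        the new points, analogous to `yhat_test_mean`.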
296 """
297 x_test, x_test_fmt = self._process_predictor_input(x_test) 1ab
298 if x_test_fmt != self._x_train_fmt: 1ab
299 raise ValueError( 1ab
300 f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
301 )
302 x_test = self._bin_predictors(x_test, self._splits) 1ab
303 return self._predict(self._main_trace, x_test) 1ab

    @staticmethod
    def _process_predictor_input(x):
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            # DataFrames have one predictor per column; transpose to the
            # internal (p, n) layout with one predictor per row
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(y):
        if hasattr(y, 'to_numpy'):
            fmt = dict(kind='series', name=y.name)
            y = y.to_numpy()
        else:
            fmt = dict(kind='array')
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y, fmt

    @staticmethod
    def _check_same_length(x1, x2):
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @staticmethod
    def _process_error_variance_settings(
        x_train, y_train, sigest, sigdf, sigquant, lamda
    ) -> tuple[Float32[Array, ''] | None, ...]:
        if y_train.dtype == bool:
            if sigest is not None:
                raise ValueError('Let `sigest=None` for binary regression')
            if lamda is not None:
                raise ValueError('Let `lamda=None` for binary regression')
            return None, None
        elif lamda is not None:
            if sigest is not None:
                raise ValueError('Let `sigest=None` if `lamda` is specified')
            return lamda, None
        else:
            if sigest is not None:
                sigest2 = jnp.square(sigest)
            elif y_train.size < 2:
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                sigest2 = jnp.var(y_train)
            else:
                x_centered = x_train.T - x_train.mean(axis=1)
                y_centered = y_train - y_train.mean()
                # centering is equivalent to adding an intercept column
                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                chisq = chisq.squeeze(0)
                dof = len(y_train) - rank
                sigest2 = chisq / dof
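            # Added commentary (not original code): the next lines choose
            # `lamda` so that the scaled-inverse-chisquared prior on the noise
            # variance, with `sigdf` degrees of freedom and scale `lamda`,
            # puts probability `sigquant` below `sigest2`. Writing
            # sigma2 = sigdf * lamda / chi2(sigdf), the condition
            # P(sigma2 <= sigest2) = sigquant is inverted with the inverse
            # gamma quantile function.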
            alpha = sigdf / 2
            invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _process_type_settings(y_train, type, w):
        match type:
            case 'wbart':
                if y_train.dtype != jnp.float32:
                    raise TypeError(
                        'Continuous regression requires y_train.dtype=float32,'
                        f' got {y_train.dtype=} instead.'
                    )
            case 'pbart':
                if w is not None:
                    raise ValueError(
                        'Binary regression does not support weights, set `w=None`'
                    )
                if y_train.dtype != bool:
                    raise TypeError(
                        'Binary regression requires y_train.dtype=bool,'
                        f' got {y_train.dtype=} instead.'
                    )
            case _:
                raise ValueError(f'Invalid {type=}')

        return y_train

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
        offset: float | Float32[Any, ''] | None,
    ) -> Float32[Array, '']:
        if offset is not None:
            return jnp.asarray(offset)
        elif y_train.size < 1:
            return jnp.array(0.0)
        else:
            mean = y_train.mean()

            if y_train.dtype == bool:
                return ndtri(mean)
            else:
                return mean

    @staticmethod
    def _process_leaf_sdev_settings(
        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
        k: float,
        ntree: int,
        tau_num: FloatLike | None,
    ):
        if tau_num is None:
            if y_train.dtype == bool:
                tau_num = 3.0
            elif y_train.size < 2:
                tau_num = 1.0
            else:
                tau_num = (y_train.max() - y_train.min()) / 2
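
        # Added commentary (not original code): each leaf gets prior standard
        # deviation tau_num / (k * sqrt(ntree)), so the sum of `ntree`
        # independent leaves has prior standard deviation tau_num / k. With
        # the default continuous-regression tau_num = (max(y) - min(y)) / 2,
        # this matches the "half data range divided by k" scale documented
        # for the `k` parameter.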
        return tau_num / (k * math.sqrt(ntree))

    @staticmethod
    def _determine_splits(x_train, usequants, numcut):
        if usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits):
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train,
        y_train,
        offset,
        w,
        max_split,
        lamda,
        sigma_mu,
        sigdf,
        power,
        base,
        maxdepth,
        ntree,
        init_kw,
    ):
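        # Added commentary (not original code): the next two lines evaluate
        # the docstring formula ``base / (1 + d) ** power`` at depths
        # d = 0, ..., maxdepth - 2; nodes at the deepest allowed level get no
        # entry, so they are always terminal.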
        depth = jnp.arange(maxdepth - 1)
        p_nonterminal = base / (1 + depth).astype(float) ** power

        if y_train.dtype == bool:
            sigma2_alpha = None
            sigma2_beta = None
        else:
            sigma2_alpha = sigdf / 2
            sigma2_beta = lamda * sigma2_alpha

        kw = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma_mu2=jnp.square(sigma_mu),
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_leaf=5,
        )
        if init_kw is not None:
            kw.update(init_kw)
        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
        if isinstance(seed, jax.Array) and jnp.issubdtype(
            seed.dtype, jax.dtypes.prng_key
        ):
            key = seed.copy()
            # copy because the inner loop in run_mcmc will donate the buffer
        else:
            key = jax.random.key(seed)

        kw = dict(
            n_burn=nskip,
            n_skip=keepevery,
            inner_loop_length=printevery,
            allow_overflow=True,
        )
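        # Added commentary (not original code): `inner_loop_length=printevery`
        # is what ties the log period to `printevery`; per the class
        # docstring, it also makes the total number of iterations a multiple
        # of `printevery`, which is presumably why overflow is allowed here.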
        if printevery is not None:
            kw.update(mcmcloop.make_print_callbacks())
        if run_mcmc_kw is not None:
            kw.update(run_mcmc_kw)

        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

    @staticmethod
    def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
        if trace['sigma2'] is None:
            return None
        else:
            return jnp.sqrt(trace['sigma2'])

    @staticmethod
    def _predict(trace, x):
        return mcmcloop.evaluate_trace(trace, x)

    def _show_tree(self, i_sample, i_tree, print_all=False):
        from . import debug

        trace = self._main_trace
        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
        var_tree = trace['var_trees'][i_sample, i_tree]
        split_tree = trace['split_trees'][i_sample, i_tree]
        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)

    def _sigma_harmonic_mean(self, prior=False):
        bart = self._mcmc_state
        if prior:
            alpha = bart['sigma2_alpha']
            beta = bart['sigma2_beta']
        else:
            resid = bart['resid']
            alpha = bart['sigma2_alpha'] + resid.size / 2
            norm2 = jnp.dot(
                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
            )
            beta = bart['sigma2_beta'] + norm2 / 2
        # for sigma2 ~ InvGamma(alpha, beta), E[1/sigma2] = alpha / beta, so
        # the harmonic mean of sigma2 is beta / alpha
        sigma2 = beta / alpha
        return jnp.sqrt(sigma2)

    def _compare_resid(self):
        bart = self._mcmc_state
        resid1 = bart.resid

        trees = grove.evaluate_forest(
            bart.X,
            bart.forest.leaf_trees,
            bart.forest.var_trees,
            bart.forest.split_trees,
            jnp.float32,  # TODO remove these configurable dtypes around
        )

        if bart.z is not None:
            ref = bart.z
        else:
            ref = bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def _avg_acc(self):
        trace = self._main_trace

        def acc(prefix):
            acc = trace[f'{prefix}_acc_count']
            prop = trace[f'{prefix}_prop_count']
            return acc.sum() / prop.sum()

        return acc('grow'), acc('prune')

    def _avg_prop(self):
        trace = self._main_trace

        def prop(prefix):
            return trace[f'{prefix}_prop_count'].sum()

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def _avg_move(self):
        agrow, aprune = self._avg_acc()
        pgrow, pprune = self._avg_prop()
        return agrow * pgrow, aprune * pprune

    def _depth_distr(self):
        from . import debug

        trace = self._main_trace
        split_trees = trace['split_trees']
        return debug.trace_depth_distr(split_trees)

    def _points_per_leaf_distr(self):
        from . import debug

        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)

    def _check_trees(self):
        from . import debug

        return debug.check_trace(self._main_trace, self._mcmc_state)

    def _tree_goes_bad(self):
        bad = self._check_trees().astype(bool)
        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
        return bad & ~bad_before