Interface
Implement a user interface that mimics the R BART package.
- class bartz.BART.gbart(x_train, y_train, *, x_test=None, type='wbart', usequants=False, sigest=None, sigdf=3, sigquant=0.9, k=2, power=2, base=0.95, lamda=None, tau_num=None, offset=None, w=None, ntree=200, numcut=255, ndpost=1000, nskip=100, keepevery=1, printevery=100, seed=0, maxdepth=6, init_kw=None, run_mcmc_kw=None)
Nonparametric regression with Bayesian Additive Regression Trees (BART).
Regress y_train on x_train with a latent mean function represented as a sum of decision trees. The inference is carried out by sampling the posterior distribution of the tree ensemble with an MCMC.
- Parameters:
x_train (array (p, n) or DataFrame) – The training predictors.
y_train (array (n,) or Series) – The training responses.
x_test (array (p, m) or DataFrame, optional) – The test predictors.
type (Literal['wbart', 'pbart']) – The type of regression. ‘wbart’ for continuous regression, ‘pbart’ for binary regression with probit link.
usequants (bool, default False) – Whether to use predictor quantiles instead of a uniform grid to bin predictors.
sigest (float, optional) – An estimate of the residual standard deviation on y_train, used to set lamda. If not specified, it is estimated by linear regression (with intercept, and without taking into account w). If y_train has fewer than two elements, it is set to 1. If n <= p, it is set to the standard deviation of y_train. Ignored if lamda is specified.
sigdf (int, default 3) – The degrees of freedom of the scaled inverse-chi-squared prior on the noise variance.
sigquant (float, default 0.9) – The quantile of the prior on the noise variance that shall match sigest to set the scale of the prior. Ignored if lamda is specified.
k (float, default 2) – The inverse scale of the prior standard deviation on the latent mean function, relative to half the observed range of y_train. If y_train has fewer than two elements, k is ignored and the scale is set to 1.
power (float, default 2)
base (float, default 0.95) – power and base are the parameters of the prior on tree node generation. The probability that a node at depth d (0-based) is non-terminal is base / (1 + d) ** power.
lamda (float | Float[Any, ''] | None) – The prior harmonic mean of the error variance. (The harmonic mean of x is 1/mean(1/x).) If not specified, it is set based on sigest and sigquant.
tau_num (float | Float[Any, ''] | None) – The numerator in the expression that determines the prior standard deviation of leaves. If not specified, it defaults to (max(y_train) - min(y_train)) / 2 (or 1 if y_train has fewer than two elements) for continuous regression, and to 3 for binary regression.
offset (float | Float[Any, ''] | None) – The prior mean of the latent mean function. If not specified, it is set to the mean of y_train for continuous regression, and to Phi^-1(mean(y_train)) for binary regression. If y_train is empty, offset is set to 0.
w (array (n,), optional) – Coefficients that rescale the error standard deviation on each datapoint. Not specifying w is equivalent to setting it to 1 for all datapoints. Note: w is ignored in the automatic determination of sigest, so either the weights should be O(1), or sigest should be specified by the user.
ntree (int, default 200) – The number of trees used to represent the latent mean function.
numcut (int, default 255) –
If usequants is False: the exact number of cutpoints used to bin the predictors, ranging between the minimum and maximum observed values (excluded).
If usequants is True: the maximum number of cutpoints to use for binning the predictors. Each predictor is binned such that its distribution in x_train is approximately uniform across bins. The number of bins is at most the number of unique values appearing in x_train, or numcut + 1.
Before running the algorithm, the predictors are compressed to the smallest integer type that fits the bin indices, so numcut is best set to the maximum value of an unsigned integer type.
ndpost (int, default 1000) – The number of MCMC samples to save, after burn-in.
nskip (int, default 100) – The number of initial MCMC samples to discard as burn-in.
keepevery (int, default 1) – The thinning factor for the MCMC samples, after burn-in.
printevery (int or None, default 100) –
The number of iterations (including thinned-away ones) between consecutive log lines. Set to None to disable logging.
printevery has a few unexpected side effects. On CPU, interrupting with ^C halts the MCMC only on the next log. Also, the total number of iterations is a multiple of printevery, so if nskip + keepevery * ndpost is not a multiple of printevery, some of the last iterations will not be saved.
seed (int or jax random key, default 0) – The seed for the random number generator.
maxdepth (int, default 6) – The maximum depth of the trees. This is 1-based, so with the default maxdepth=6, the depths of the levels range from 0 to 5.
init_kw (dict) – Additional arguments passed to mcmcstep.init.
run_mcmc_kw (dict) – Additional arguments passed to mcmcloop.run_mcmc.
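As an illustration of two of the parameter descriptions above, the following minimal sketch evaluates the node-splitting prior base / (1 + d) ** power and checks whether nskip + keepevery * ndpost is a multiple of printevery. The function names (split_probability, splittable_depths, iterations_aligned) are hypothetical helpers written for this example; they are not part of bartz.

```python
def split_probability(depth, base=0.95, power=2.0):
    """Prior probability that a node at 0-based `depth` is non-terminal.

    Hypothetical helper: it only evaluates the documented formula
    base / (1 + depth) ** power.
    """
    return base / (1 + depth) ** power


def splittable_depths(maxdepth=6):
    """Depths at which a node can still have children.

    maxdepth is 1-based, so levels run from 0 to maxdepth - 1, and a node
    on the deepest level has no room for children.
    """
    return range(maxdepth - 1)


def iterations_aligned(ndpost=1000, nskip=100, keepevery=1, printevery=100):
    """True if nskip + keepevery * ndpost is a multiple of printevery."""
    return (nskip + keepevery * ndpost) % printevery == 0


if __name__ == "__main__":
    for d in splittable_depths():
        print(f"depth {d}: P(non-terminal) = {split_probability(d):.4f}")
    print("defaults aligned with printevery:", iterations_aligned())
```

With the default base and power, the probability drops quickly with depth, which is what keeps individual trees shallow. The alignment check is only a convenience for spotting settings where, according to the printevery description, some of the last iterations would not be saved.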
- Variables:
yhat_train (array (ndpost, n)) – The conditional posterior mean at x_train for each MCMC iteration.
yhat_train_mean (array (n,)) – The marginal posterior mean at x_train.
yhat_test (array (ndpost, m)) – The conditional posterior mean at x_test for each MCMC iteration.
yhat_test_mean (array (m,)) – The marginal posterior mean at x_test.
sigma (array (ndpost,)) – The standard deviation of the error.
first_sigma (array (nskip,)) – The standard deviation of the error in the burn-in phase.
offset (float) – The prior mean of the latent mean function.
sigest (float or None) – The estimated standard deviation of the error used to set lamda.
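The offset attribute records the prior mean that was actually used. The sketch below restates, in plain Python, the documented defaults for offset and tau_num when they are not passed explicitly. It is an illustration of the rules in the parameter descriptions, not the library's code: default_offset, default_tau_num and the kind argument (standing in for type) are hypothetical names, and the all-zeros or all-ones binary case is deliberately not handled.

```python
from statistics import NormalDist, fmean


def default_offset(y_train, kind="wbart"):
    """Documented default for `offset` (hypothetical helper, not bartz code)."""
    if len(y_train) == 0:
        return 0.0  # empty y_train: offset is set to 0
    mean = fmean(y_train)
    if kind == "wbart":
        return mean  # continuous regression: mean of y_train
    # Binary regression: Phi^-1(mean(y_train)), the standard normal quantile.
    # inv_cdf raises StatisticsError if the mean is exactly 0 or 1.
    return NormalDist().inv_cdf(mean)


def default_tau_num(y_train, kind="wbart"):
    """Documented default for `tau_num` (hypothetical helper, not bartz code)."""
    if kind == "pbart":
        return 3.0  # binary regression
    if len(y_train) < 2:
        return 1.0  # fewer than two responses
    return (max(y_train) - min(y_train)) / 2  # half the observed range
```

For example, default_offset([1.0, 2.0, 3.0]) gives 2.0, and a balanced binary response such as [0, 1, 1, 0] gives an offset of 0 under the probit link.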
Notes
This interface imitates the function gbart from the R package BART, but with these differences:
If x_train and x_test are matrices, they have one predictor per row instead of per column.
If type is not specified, it is determined solely based on the data type of y_train, and not on whether it contains only two unique values.
If usequants=False, R BART switches to quantiles anyway if there are fewer predictor values than the required number of bins, while bartz always follows the specification.
The error variance parameter is called lamda instead of lambda.
rm_const is always False.
The default numcut is 255 instead of 100.
A lot of functionality is missing (e.g., variable selection).
There are some additional attributes, and some missing.
The trees have a maximum depth.
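The first difference is the easiest to trip over when porting code from R: a matrix laid out with one observation per row (shape (n, p)) has to be transposed to one predictor per row (shape (p, n)). A minimal pure-Python sketch of that transposition follows; to_predictor_rows is a hypothetical helper written for this example and operates on nested lists only.

```python
def to_predictor_rows(observations):
    """Transpose an (n, p) nested list into (p, n).

    Hypothetical helper: each input row is one observation, each output
    row collects the values of one predictor across all observations.
    """
    return [list(column) for column in zip(*observations)]


if __name__ == "__main__":
    # Two observations of three predictors, in the one-observation-per-row layout.
    r_style = [[1, 2, 3],
               [4, 5, 6]]
    print(to_predictor_rows(r_style))
```

With NumPy arrays the same operation is simply the .T attribute.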
- predict(x_test)
Compute the posterior mean at x_test for each MCMC iteration.
- Parameters:
x_test (array (p, m) or DataFrame) – The test predictors.
- Returns:
yhat_test (array (ndpost, m)) – The conditional posterior mean at x_test for each MCMC iteration.
- Raises:
ValueError – If x_test has a different format than x_train.
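The array returned by predict has one row per MCMC iteration, so a pointwise summary requires reducing over the first axis. The sketch below averages the draws for each test point, which is presumably how a quantity like yhat_test_mean relates to yhat_test; that relationship is an assumption, not something stated above. posterior_mean is a hypothetical helper, and the nested list stands in for real output.

```python
from statistics import fmean


def posterior_mean(draws):
    """Average an (ndpost, m) nested list over the MCMC axis.

    Hypothetical helper: returns a list of m values, one per test point.
    """
    return [fmean(column) for column in zip(*draws)]


if __name__ == "__main__":
    # Stand-in for predict output: 3 MCMC iterations, 2 test points.
    fake_draws = [[1.0, 10.0],
                  [2.0, 12.0],
                  [3.0, 14.0]]
    print(posterior_mean(fake_draws))
```

With array output, the equivalent reduction is a mean along axis 0.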