11. BART
The bayestree submodule contains a class to set up a Gaussian process regression with the BART kernel. See the bart, barteasy and acic examples.
- class lsqfitgp.bayestree.bart(x_train, y_train, *, weights=None, fitkw={}, kernelkw={})
GP version of BART.
Evaluate a Gaussian process regression with a kernel which accurately approximates the infinite trees limit of BART. The hyperparameters are optimized to their marginal MAP.
- Parameters:
- x_train : (n, p) array or dataframe
Observed covariates.
- y_train : (n,) array
Observed outcomes.
- weights : (n,) array
Weights used to rescale the error variance (as 1 / weight).
- fitkw : dict
Additional arguments passed to empbayes_fit, overriding the defaults.
- kernelkw : dict
Additional arguments passed to BART, overriding the defaults.
See also
Notes
The tree splitting grid is set using quantiles of the observed covariates. This corresponds to settings usequants=True, numcut=inf in the R packages BayesTree and BART. Use the kernelkw parameter to customize the grid.
- Attributes:
- mean : gvar
The prior mean.
- sigma : gvar
The error term standard deviation. If there are weights, the sdev for each unit is obtained by dividing sigma by sqrt(weight).
- alpha : gvar
The numerator of the tree spawn probability (named base in BayesTree and BART).
- beta : gvar
The depth exponent of the tree spawn probability (named power in BayesTree and BART).
- meansdev : gvar
The prior standard deviation of the latent regression function.
- fit : empbayes_fit
The hyperparameters fit object.
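To make the roles of alpha and beta concrete, here is a minimal sketch assuming the standard BART tree prior, in which a node at depth d (the root has depth 0) is split further with probability alpha / (1 + d) ** beta. The function name spawn_probability is hypothetical, and the values 0.95 and 2 are the usual defaults of base and power in the R packages, not values fitted by this class.

```python
def spawn_probability(depth, alpha, beta):
    """Prior probability that a node at the given depth is split further.

    Assumes the standard BART prior: alpha / (1 + depth) ** beta.
    """
    return alpha / (1 + depth) ** beta


# Hypothetical hyperparameter values (defaults of `base` and `power` in the
# R packages); the bart class fits its own alpha and beta instead.
alpha, beta = 0.95, 2.0

# Probability of splitting at the first few depths, starting from the root.
probs = [spawn_probability(d, alpha, beta) for d in range(4)]
for d, p in enumerate(probs):
    print(f"depth {d}: P(split) = {p:.4f}")
```

A larger beta makes the probability decay faster with depth, which favors shallower trees. A smaller alpha lowers the probability at every depth.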
Methods
- gp(*[, hp, x_test, weights, rng]): Create a Gaussian process with the fitted hyperparameters.
- data(*[, hp, rng]): Get the data to be passed to GP.pred on a GP object returned by gp.
- pred(*[, hp, error, format, x_test, ...]): Predict the outcome at given locations.
- data(*, hp='map', rng=None)
Get the data to be passed to GP.pred on a GP object returned by gp.
- Parameters:
- hp : str or dict
The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.
- rng : numpy.random.Generator, optional
Random number generator, used if hp == 'sample'.
- Returns:
- data : dict
A dictionary representing y_train in the format required by the GP.pred method.
- gp(*, hp='map', x_test=None, weights=None, rng=None)
Create a Gaussian process with the fitted hyperparameters.
- Parameters:
- hp : str or dict
The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.
- x_test : array or dataframe, optional
Additional covariates for "test points".
- weights : array, optional
Weights for the error variance on the test points.
- rng : numpy.random.Generator, optional
Random number generator, used if hp == 'sample'.
- Returns:
- gp : GP
A centered Gaussian process object. To add the mean, use the mean attribute of the bart object. The keys of the GP are '*_mean', '*_noise', and '*', where the "*" stands either for 'train' or 'test'.
- pred(*, hp='map', error=False, format='matrices', x_test=None, weights=None, rng=None)
Predict the outcome at given locations.
- Parameters:
- hp : str or dict
The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.
- error : bool
If False (default), make a prediction for the latent mean. If True, add the error term.
- format : {'matrices', 'gvar'}
If 'matrices' (default), return the mean and covariance matrix separately. If 'gvar', return an array of GVar objects.
- x_test : array or dataframe, optional
Covariates for the locations where the prediction is computed. If not specified, predict at the data covariates.
- weights : array, optional
Weights for the error variance on the test points.
- rng : numpy.random.Generator, optional
Random number generator, used if hp == 'sample'.
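The difference between error=False and error=True can be illustrated with plain arithmetic. The sketch below assumes that the error term is independent across points, so that it only adds sigma ** 2 / weight to the diagonal of the latent covariance matrix. This is an assumption about the model, not a description of how pred is implemented. The helper add_error_term and all numbers are hypothetical.

```python
def add_error_term(cov, sigma, weights=None):
    """Return a copy of `cov` with sigma**2 / weight added to the diagonal.

    Assumes independent errors, with variance rescaled as 1 / weight.
    """
    n = len(cov)
    if weights is None:
        weights = [1.0] * n  # no weights: same error variance everywhere
    out = [list(row) for row in cov]  # copy, leave the input untouched
    for i in range(n):
        out[i][i] += sigma ** 2 / weights[i]
    return out


# Hypothetical latent covariance at two test points, error sdev and weights.
latent_cov = [[0.5, 0.1],
              [0.1, 0.4]]
sigma = 0.3
weights = [1.0, 4.0]

pred_cov = add_error_term(latent_cov, sigma, weights)
print(pred_cov)
```

Under this assumption the off-diagonal entries are unchanged, and a point with a larger weight receives a smaller added variance.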
- Returns:
- If format is 'matrices' (default):
- mean, cov : arrays
The mean and covariance matrix of the Normal posterior distribution over the regression function at the specified locations.
- If format is 'gvar':
- out : array of GVar
The same distribution represented as an array of GVar objects.
- If