11. BART
The bayestree submodule contains a class to set up a Gaussian process
regression with the BART kernel. See the bart,
barteasy and acic examples.
- class lsqfitgp.bayestree.bart(x_train, y_train, *, weights=None, fitkw={}, kernelkw={})
GP version of BART.
Evaluate a Gaussian process regression with a kernel which accurately approximates the infinite-trees limit of BART. The hyperparameters are set to their marginal maximum a posteriori (MAP) values.
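As a rough illustration of what "marginal MAP" means here, the sketch below chooses a single hyperparameter (the error sdev) by maximizing the log marginal likelihood of a zero-mean GP plus a log prior. It is a generic toy under stated assumptions: the RBF kernel is a hypothetical stand-in for the BART kernel, the standard normal prior on log(sigma) is assumed, and a grid search replaces a proper numerical optimizer.

```python
import numpy as np

def rbf_kernel(x1, x2):
    # Hypothetical stand-in for the BART kernel.
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2)

def log_marginal_likelihood(x, y, sigma):
    # log N(y | 0, K + sigma**2 I), computed via a Cholesky factorization.
    n = y.size
    cov = rbf_kernel(x, x) + sigma**2 * np.eye(n)
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, y)  # y^T cov^-1 y equals z^T z
    return -0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * n * np.log(2 * np.pi)

rng = np.random.default_rng(0)
x = np.linspace(0, 5, 30)
y = np.sin(x) + 0.3 * rng.standard_normal(x.size)

# Marginal log posterior of u = log(sigma), assuming the prior u ~ N(0, 1).
log_sigmas = np.linspace(-4, 2, 301)
log_post = [log_marginal_likelihood(x, y, np.exp(u)) - 0.5 * u**2
            for u in log_sigmas]
sigma_map = np.exp(log_sigmas[np.argmax(log_post)])
```

The latent function is integrated out analytically, so only the hyperparameter is optimized; this is what distinguishes a marginal MAP from a joint MAP over function values and hyperparameters.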
- Parameters:
- x_train : (n, p) array or dataframe
Observed covariates.
- y_train : (n,) array
Observed outcomes.
- weights : (n,) array
Weights used to rescale the error variance: the variance of each unit's error term is divided by its weight.
- fitkw : dict
Additional arguments passed to empbayes_fit; they override the defaults.
- kernelkw : dict
Additional arguments passed to BART; they override the defaults.
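To make the weighting convention concrete: a weight w divides the error variance, so a unit's error sdev is sigma / sqrt(w), consistent with the sigma attribute. A minimal numpy sketch, using a hypothetical helper name error_sdev:

```python
import numpy as np

def error_sdev(sigma, weights):
    """Per-unit error sdev: the error variance is sigma**2 / weight."""
    weights = np.asarray(weights, dtype=float)
    return sigma / np.sqrt(weights)

sigma = 2.0
weights = np.array([1.0, 4.0, 0.25])
sdev = error_sdev(sigma, weights)
# sdev is [2, 1, 4]: weight 4 halves the sdev, weight 0.25 doubles it
```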
Notes
The tree splitting grid is set using quantiles of the observed covariates. This corresponds to the settings usequants=True, numcut=inf in the R packages BayesTree and BART. Use the kernelkw parameter to customize the grid.
- Attributes:
- mean : gvar
The prior mean.
- sigma : gvar
The error term standard deviation. If there are weights, the sdev for each unit is obtained by dividing sigma by sqrt(weight).
- alpha : gvar
The numerator of the tree spawn probability (named base in BayesTree and BART).
- beta : gvar
The depth exponent of the tree spawn probability (named power in BayesTree and BART).
- meansdev : gvar
The prior standard deviation of the latent regression function.
- fit : empbayes_fit
The hyperparameters fit object.
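The alpha and beta attributes enter the standard BART tree prior, in which a node at depth d (the root has d = 0) is split with probability alpha / (1 + d)**beta; the BayesTree defaults are base = 0.95, power = 2. The sketch below computes that probability and also a quantile-based splitting grid in the spirit of the Notes, assuming one cutpoint midway between each pair of consecutive distinct observed values (the exact grid convention used by lsqfitgp may differ). Both function names are hypothetical.

```python
import numpy as np

def spawn_probability(depth, alpha, beta):
    # Probability that a node at the given depth (root = 0) is split
    # under the BART tree prior: alpha / (1 + depth) ** beta.
    return alpha / (1.0 + depth) ** beta

def splitting_grid(X):
    """Assumed cutpoints per covariate: midpoints between consecutive
    distinct observed values (a quantile grid with unlimited cuts)."""
    X = np.asarray(X, dtype=float)
    grid = []
    for column in X.T:
        values = np.unique(column)  # sorted distinct values
        grid.append((values[1:] + values[:-1]) / 2)
    return grid

probs = [spawn_probability(d, alpha=0.95, beta=2.0) for d in range(3)]
# probs is about [0.95, 0.2375, 0.1056]: deep trees become increasingly unlikely

X = np.array([[3, 10], [1, 10], [2, 20], [2, 30], [5, 10]])
grid = splitting_grid(X)
# grid[0] is [1.5, 2.5, 4.0] and grid[1] is [15.0, 25.0]
```

With an unlimited number of cuts, every distinct partition of the observed values of a covariate is available to the trees, which is what numcut=inf expresses.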
Methods
gp(*[, hp, x_test, weights, rng]) Create a Gaussian process with the fitted hyperparameters.
data(*[, hp, rng]) Get the data to be passed to GP.pred on a GP object returned by gp.
pred(*[, hp, error, format, x_test, ...]) Predict the outcome at given locations.
- data(*, hp='map', rng=None)
Get the data to be passed to GP.pred on a GP object returned by gp.
- Parameters:
- hp : str or dict
The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.
- rng : numpy.random.Generator, optional
Random number generator, used if hp == 'sample'.
- Returns:
- data : dict
A dictionary representing y_train in the format required by the GP.pred method.
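The hp argument, shared by data, gp and pred, accepts three forms. The sketch below mimics that dispatch with a hypothetical helper resolve_hp, assuming that 'sample' draws from a Gaussian approximation of the hyperparameter posterior centered on the MAP; the hyperparameter names and covariance are made up for illustration and are not lsqfitgp internals.

```python
import numpy as np

def resolve_hp(hp, hp_map, hp_cov, rng=None):
    """Return a dict of hyperparameters from an hp argument (hypothetical helper).

    hp_map: dict of MAP values; hp_cov: covariance of the assumed Gaussian
    posterior approximation, ordered like hp_map.
    """
    if isinstance(hp, dict):
        return dict(hp)                    # use the given hyperparameters
    if hp == 'map':
        return dict(hp_map)                # marginal maximum a posteriori
    if hp == 'sample':
        if rng is None:
            rng = np.random.default_rng()
        names = list(hp_map)
        center = np.array([hp_map[k] for k in names])
        draw = rng.multivariate_normal(center, hp_cov)
        return dict(zip(names, draw))      # one approximate posterior sample
    raise ValueError(f"hp must be 'map', 'sample' or a dict, got {hp!r}")

hp_map = {'log_sigma': -1.0, 'log_meansdev': 0.0}  # hypothetical names and values
hp_cov = np.diag([0.01, 0.04])
sample = resolve_hp('sample', hp_map, hp_cov, rng=np.random.default_rng(42))
```

Passing an explicit Generator makes the 'sample' case reproducible, which is why rng only matters when hp == 'sample'.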
- gp(*, hp='map', x_test=None, weights=None, rng=None)
Create a Gaussian process with the fitted hyperparameters.
- Parameters:
- hp : str or dict
The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.
- x_test : array or dataframe, optional
Additional covariates for “test points”.
- weights : array, optional
Weights for the error variance on the test points.
- rng : numpy.random.Generator, optional
Random number generator, used if hp == 'sample'.
- Returns:
- gp : GP
A centered Gaussian process object. To add the mean, use the mean attribute of the bart object. The keys of the GP are ‘*mean’, ‘*noise’, and ‘*’, where “*” stands for either ‘train’ or ‘test’.
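The returned GP is centered and splits each outcome into a latent part and an independent error part, so the outcome key is their sum. A numpy sketch of the resulting prior covariance blocks for the training points, assuming keys named 'trainmean', 'trainnoise' and 'train', a hypothetical RBF kernel in place of the BART kernel, and a latent prior variance of meansdev**2:

```python
import numpy as np

def rbf(x1, x2):
    # Hypothetical stand-in for the BART kernel (equal to 1 on the diagonal).
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2)

def centered_gp_blocks(x, sigma, weights, meansdev=1.0):
    """Prior covariance blocks for the (assumed) training keys."""
    x = np.asarray(x, dtype=float)
    weights = np.asarray(weights, dtype=float)
    K = meansdev**2 * rbf(x, x)        # latent regression function ('trainmean')
    N = np.diag(sigma**2 / weights)    # independent error term ('trainnoise')
    return {
        ('trainmean', 'trainmean'): K,
        ('trainnoise', 'trainnoise'): N,
        ('train', 'train'): K + N,     # outcome = latent + error
        ('trainmean', 'train'): K,     # the error is independent of the latent part
    }

x = np.array([0.0, 0.5, 2.0])
w = np.array([1.0, 2.0, 4.0])
blocks = centered_gp_blocks(x, sigma=0.5, weights=w, meansdev=2.0)
# The GP is centered: the prior mean of 'train' is the `mean` attribute,
# which has to be added separately.
```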
- pred(*, hp='map', error=False, format='matrices', x_test=None, weights=None, rng=None)
Predict the outcome at given locations.
- Parameters:
- hp : str or dict
The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.
- error : bool
If False (default), make a prediction for the latent mean. If True, add the error term.
- format : {‘matrices’, ‘gvar’}
If ‘matrices’ (default), return the mean and covariance matrix separately. If ‘gvar’, return an array of GVar objects.
- x_test : array or dataframe, optional
Covariates for the locations where the prediction is computed. If not specified, predict at the data covariates.
- weights : array, optional
Weights for the error variance on the test points.
- rng : numpy.random.Generator, optional
Random number generator, used if hp == 'sample'.
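The prediction is a Gaussian process posterior. The sketch below shows the textbook regression formulas with an RBF kernel standing in for the BART kernel and a hypothetical function gp_predict: the mean is the prior mean plus K_*x (K + N)^(-1) (y - mean), the covariance is K_** - K_*x (K + N)^(-1) K_x*, and error=True adds the test error variances sigma**2 / weight to the diagonal. Adding the prior mean back is an assumption of this sketch.

```python
import numpy as np

def rbf(x1, x2):
    # Hypothetical stand-in for the BART kernel (unit variance on the diagonal).
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2)

def gp_predict(x_train, y_train, x_test, mean, sigma,
               weights=None, weights_test=None, error=False):
    """Posterior mean and covariance at x_test (hypothetical helper)."""
    x_train = np.asarray(x_train, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    x_test = np.asarray(x_test, dtype=float)
    w = np.ones(x_train.size) if weights is None else np.asarray(weights, dtype=float)

    # Covariance of the training outcomes: latent kernel plus error variances.
    K = rbf(x_train, x_train) + np.diag(sigma**2 / w)
    K_star = rbf(x_test, x_train)  # test/train cross-covariance
    K_ss = rbf(x_test, x_test)     # prior covariance at the test points

    post_mean = mean + K_star @ np.linalg.solve(K, y_train - mean)
    post_cov = K_ss - K_star @ np.linalg.solve(K, K_star.T)

    if error:
        # Predict the outcome rather than the latent function.
        w_test = (np.ones(x_test.size) if weights_test is None
                  else np.asarray(weights_test, dtype=float))
        post_cov = post_cov + np.diag(sigma**2 / w_test)
    return post_mean, post_cov

x = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([1.0, 2.0, 0.5, 1.5])
m_latent, cov_latent = gp_predict(x, y, x, mean=1.0, sigma=1e-3)
m_outcome, cov_outcome = gp_predict(x, y, x, mean=1.0, sigma=1e-3, error=True)
```

The error flag leaves the posterior mean unchanged and only widens the covariance, because the error term is independent noise with zero mean.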
- Returns:
- If format is ‘matrices’ (default):
- mean, cov : arrays
The mean and covariance matrix of the Normal posterior distribution over the regression function at the specified locations.
- If format is ‘gvar’:
- out : array of GVar
The same distribution represented as an array of GVar objects.
- If