11. BART

The bayestree submodule contains a class to set up a Gaussian process regression with the BART kernel. See the bart, barteasy and acic examples.

class lsqfitgp.bayestree.bart(x_train, y_train, *, weights=None, fitkw={}, kernelkw={})

GP version of BART.

Evaluate a Gaussian process regression with a kernel that accurately approximates the infinite-trees limit of BART. The hyperparameters are optimized to their marginal maximum a posteriori (MAP) values.

Parameters:
x_train : (n, p) array or dataframe

Observed covariates.

y_train : (n,) array

Observed outcomes.

weights : (n,) array

Weights used to rescale the error variance (as 1 / weight).

fitkw : dict

Additional arguments passed to empbayes_fit, overriding the defaults.

kernelkw : dict

Additional arguments passed to BART, overriding the defaults.

See also

lsqfitgp.BART

Notes

The tree splitting grid is set using quantiles of the observed covariates. This corresponds to the settings usequants=True, numcut=inf in the R packages BayesTree and BART. Use the kernelkw parameter to customize the grid.
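
As a rough illustration of what a data-driven splitting grid looks like, the sketch below builds, for each covariate, the midpoints between consecutive distinct observed values. This is one common reading of the usequants=True, numcut=inf limit; that the library constructs its grid in exactly this way is an assumption, and midpoint_grid is a hypothetical helper, not part of lsqfitgp.

```python
def midpoint_grid(column):
    """Candidate split points for one covariate (hypothetical helper).

    Assumption: with an unbounded number of cuts, the grid reduces to the
    midpoints between consecutive distinct observed values.
    """
    values = sorted(set(column))
    return [(lo + hi) / 2 for lo, hi in zip(values, values[1:])]


# Toy covariate matrix with n = 5 units and p = 2 covariates.
x_train = [
    [0.0, 10.0],
    [1.0, 10.0],
    [1.0, 30.0],
    [4.0, 20.0],
    [2.0, 30.0],
]

# One grid per covariate (column of x_train).
grids = [midpoint_grid(col) for col in zip(*x_train)]
print(grids)  # [[0.5, 1.5, 3.0], [15.0, 25.0]]
```

Repeated values do not add split points, so a covariate with k distinct values yields k - 1 candidate splits.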

Attributes:
mean : gvar

The prior mean.

sigma : gvar

The error term standard deviation. If there are weights, the sdev for each unit is obtained by dividing sigma by sqrt(weight).

alpha : gvar

The numerator of the tree spawn probability (named base in BayesTree and BART).

beta : gvar

The depth exponent of the tree spawn probability (named power in BayesTree and BART).

meansdev : gvar

The prior standard deviation of the latent regression function.

fit : empbayes_fit

The hyperparameters fit object.
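
The roles of sigma, alpha and beta can be made concrete with two small formulas. The sketch below assumes the standard BART tree prior, in which a node at depth d (root at depth 0) is split with probability alpha / (1 + d) ** beta, and applies the weight rule for the per-unit error standard deviation. The function names and numerical values are illustrative and are not taken from the library.

```python
import math


def spawn_probability(alpha, beta, depth):
    """Probability that a node at the given depth (root = 0) is split.

    Assumption: the standard BART tree prior, alpha / (1 + depth) ** beta.
    """
    return alpha / (1 + depth) ** beta


def unit_sdev(sigma, weights):
    """Per-unit error standard deviation: sigma / sqrt(weight)."""
    return [sigma / math.sqrt(w) for w in weights]


# Illustrative values: 0.95 and 2 are the usual BART defaults for base and power.
probs = [spawn_probability(0.95, 2, d) for d in range(3)]
print(probs)

# A unit with weight 4 has half the error sdev of a unit with weight 1.
sdevs = unit_sdev(2.0, [1.0, 4.0, 0.25])
print(sdevs)  # [2.0, 1.0, 4.0]
```

A larger beta makes the split probability decay faster with depth, which favors shallower trees.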

Methods

gp(*[, hp, x_test, weights, rng])

Create a Gaussian process with the fitted hyperparameters.

data(*[, hp, rng])

Get the data to be passed to GP.pred on a GP object returned by gp.

pred(*[, hp, error, format, x_test, ...])

Predict the outcome at given locations.

data(*, hp='map', rng=None)

Get the data to be passed to GP.pred on a GP object returned by gp.

Parameters:
hp : str or dict

The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.

rng : numpy.random.Generator, optional

Random number generator, used if hp == 'sample'.

Returns:
data : dict

A dictionary representing y_train in the format required by the GP.pred method.

gp(*, hp='map', x_test=None, weights=None, rng=None)

Create a Gaussian process with the fitted hyperparameters.

Parameters:
hp : str or dict

The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.

x_test : array or dataframe, optional

Additional covariates for “test points”.

weights : array, optional

Weights for the error variance on the test points.

rng : numpy.random.Generator, optional

Random number generator, used if hp == 'sample'.

Returns:
gp : GP

A centered Gaussian process object. To add the mean, use the mean attribute of the bart object. The keys of the GP are ‘*mean’, ‘*noise’, and ‘*’, where the “*” stands for either ‘train’ or ‘test’.

pred(*, hp='map', error=False, format='matrices', x_test=None, weights=None, rng=None)

Predict the outcome at given locations.

Parameters:
hp : str or dict

The hyperparameters to use. If 'map', use the marginal maximum a posteriori. If 'sample', sample hyperparameters from the posterior. If a dict, use the given hyperparameters.

error : bool

If False (default), make a prediction for the latent mean. If True, add the error term.

format : {‘matrices’, ‘gvar’}

If ‘matrices’ (default), return the mean and covariance matrix separately. If ‘gvar’, return an array of GVar objects.

x_test : array or dataframe, optional

Covariates for the locations where the prediction is computed. If not specified, predict at the data covariates.

weights : array, optional

Weights for the error variance on the test points.

rng : numpy.random.Generator, optional

Random number generator, used if hp == 'sample'.

Returns:
If format is ‘matrices’ (default):
mean, cov : arrays

The mean and covariance matrix of the Normal posterior distribution over the regression function at the specified locations.

If format is ‘gvar’:
out : array of GVar

The same distribution represented as an array of GVar objects.
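
When format is ‘matrices’, pointwise uncertainties have to be extracted from the covariance matrix by hand. The sketch below does this for a made-up mean vector and covariance matrix standing in for a (mean, cov) pair returned by pred. The numbers are placeholders rather than actual output, and pointwise_bands is a hypothetical helper.

```python
import math


def pointwise_bands(mean, cov, z=1.96):
    """Return (lower, upper) Normal bands from a mean vector and covariance.

    Only the diagonal of cov enters: it holds the marginal variances.
    """
    sdev = [math.sqrt(cov[i][i]) for i in range(len(mean))]
    lower = [m - z * s for m, s in zip(mean, sdev)]
    upper = [m + z * s for m, s in zip(mean, sdev)]
    return lower, upper


# Placeholder values standing in for a (mean, cov) pair returned by pred.
mean = [1.0, 2.0]
cov = [
    [0.25, 0.10],
    [0.10, 1.00],
]

lower, upper = pointwise_bands(mean, cov, z=2.0)
print(lower, upper)  # [0.0, 0.0] [2.0, 4.0]
```

The off-diagonal entries carry the correlations between locations. They matter for joint statements, such as the uncertainty on a difference or an average of predictions, but not for pointwise bands.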