Coverage for src/lsqfitgp/bayestree/_bart.py: 84%
143 statements
coverage.py v7.9.2, created at 2025-07-17 13:39 +0000
# lsqfitgp/bayestree/_bart.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools

import numpy
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter

# TODO I added a lot of functionality to bcf. The easiest way to port it over is
# adding the option in bcf to drop the second bart model and its associated
# hypers, and then write bart as a simple convenience wrapper-subclass over bcf.
# (also the option include_pi='none'.)

class bart:

    def __init__(self,
        x_train,
        y_train,
        *,
        weights=None,
        fitkw={},
        kernelkw={},
        marginalize_mean=True,
    ):
50 """
51 Nonparametric Bayesian regression with a GP version of BART.
53 Evaluate a Gaussian process regression with a kernel which accurately
54 approximates the infinite trees limit of BART. The hyperparameters are
55 optimized to their marginal MAP.
57 Parameters
58 ----------
59 x_train : (n, p) array or dataframe
60 Observed covariates.
61 y_train : (n,) array
62 Observed outcomes.
63 weights : (n,) array
64 Weights used to rescale the error variance (as 1 / weight).
65 fitkw : dict
66 Additional arguments passed to `~lsqfitgp.empbayes_fit`, overrides
67 the defaults.
68 kernelkw : dict
69 Additional arguments passed to `~lsqfitgp.BART`, overrides the
70 defaults.
71 marginalize_mean : bool
72 If True (default), marginalize the intercept of the model.
74 Notes
75 -----
76 The regression model is:
78 .. math::
79 y_i &= \\mu + \\lambda f(\\mathbf x_i) + \\varepsilon_i, \\\\
80 \\varepsilon_i &\\overset{\\mathrm{i.i.d.}}{\\sim}
81 N(0, \\sigma^2 / w_i), \\\\
82 \\mu &\\sim N(
83 (\\max(\\mathbf y) + \\min(\\mathbf y)) / 2,
84 (\\max(\\mathbf y) - \\min(\\mathbf y))^2 / 4
85 ), \\\\
86 \\log \\sigma^2 &\\sim N(
87 \\log(\\overline{w(y - \\bar y)^2}),
88 4
89 ), \\\\
90 \\log \\lambda &\\sim N(
91 \\log ((\\max(\\mathbf y) - \\min(\\mathbf y)) / 4),
92 4
93 ), \\\\
94 f &\\sim \\mathrm{GP}(
95 0,
96 \\mathrm{BART}(\\alpha,\\beta)
97 ), \\\\
98 \\alpha &\\sim \\mathrm{B}(2, 1), \\\\
99 \\beta &\\sim \\mathrm{IG}(1, 1).
101 To make the inference, :math:`(f, \\boldsymbol\\varepsilon, \\mu)` are
102 marginalized analytically, and the marginal posterior mode of
103 :math:`(\\sigma, \\lambda, \\alpha, \\beta)` is found by numerical
104 minimization, after transforming them to express their prior as a
105 Gaussian copula. Their marginal posterior covariance matrix is estimated
106 with an approximation of the hessian inverse. See
107 `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize
108 this procedure.
110 The tree splitting grid of the BART kernel is set using quantiles of the
111 observed covariates. This corresponds to settings ``usequants=True``,
112 ``numcut=inf`` in the R packages BayesTree and BART. Use the
113 ``kernelkw`` parameter to customize the grid.
115 Attributes
116 ----------
117 mean : gvar
118 The prior mean :math:`\\mu`.
119 sigma : float or gvar
120 The error term standard deviation :math:`\\sigma`. If there are
121 weights, the sdev for each unit is obtained dividing ``sigma`` by
122 sqrt(weight).
123 alpha : gvar
124 The numerator of the tree spawn probability :math:`\\alpha` (named
125 ``base`` in BayesTree and BART).
126 beta : gvar
127 The depth exponent of the tree spawn probability :math:`\\beta`
128 (named ``power`` in BayesTree and BART).
129 meansdev : gvar
130 The prior standard deviation :math:`\\lambda` of the latent
131 regression function.
132 fit : empbayes_fit
133 The hyperparameters fit object.
135 Methods
136 -------
137 gp :
138 Create a GP object.
139 data :
140 Creates the dictionary to be passed to `GP.pred` to represent
141 ``y_train``.
142 pred :
143 Evaluate the regression function at given locations.
145 See also
146 --------
147 lsqfitgp.BART
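
        Examples
        --------
        A minimal usage sketch; the synthetic data below is illustrative, and
        the import path ``lsqfitgp.bayestree.bart`` is assumed rather than
        quoted from the package documentation:

        >>> import numpy as np
        >>> import lsqfitgp as lgp
        >>> rng = np.random.default_rng(0)
        >>> X = rng.standard_normal((100, 3))       # (n, p) covariates
        >>> y = X @ [1., -1., 0.5] + rng.standard_normal(100)
        >>> bt = lgp.bayestree.bart(X, y)           # fits the hyperparameters
        >>> mean, cov = bt.pred(x_test=X[:5])       # posterior of the latent mean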
149 """
151 # convert covariates to StructuredArray
152 x_train = self._to_structured(x_train) 1abcde
154 # convert outcomes to 1d array
155 if hasattr(y_train, 'to_numpy'): 1abcde
156 y_train = y_train.to_numpy() 1a
157 y_train = y_train.squeeze() # for dataframes 1a
158 y_train = jnp.asarray(y_train) 1abcde
159 assert y_train.shape == x_train.shape 1abcde
161 # check weights
162 self._no_weights = weights is None 1abcde
163 if self._no_weights: 163 ↛ 165line 163 didn't jump to line 165 because the condition on line 163 was always true1abcde
164 weights = jnp.ones_like(y_train) 1abcde
165 assert weights.shape == y_train.shape 1abcde
167 # prior mean and variance
168 ymin = jnp.min(y_train) 1abcde
169 ymax = jnp.max(y_train) 1abcde
170 mu_mu = (ymax + ymin) / 2 1abcde
171 k_sigma_mu = (ymax - ymin) / 2 1abcde
173 # splitting points and indices
174 splits = _kernels.BART.splits_from_coord(x_train) 1abcde
175 i_train = self._toindices(x_train, splits) 1abcde
177 # prior on hyperparams
178 sigma2_priormean = numpy.mean((y_train - y_train.mean()) ** 2 * weights) 1abcde
179 hyperprior = copula.makedict({ 1abcde
180 'alpha': copula.beta(2, 1), # base of tree gen prob
181 'beta': copula.invgamma(1, 1), # exponent of tree gen prob
182 'log(k)': gvar.gvar(numpy.log(2), 2), # denominator of prior sdev
183 'log(sigma2)': gvar.gvar(numpy.log(sigma2_priormean), 2),
184 # i.i.d. error variance, scaled with weights
185 'mean': gvar.gvar(mu_mu, k_sigma_mu), # mean of the GP
186 })
187 if marginalize_mean: 187 ↛ 191line 187 didn't jump to line 191 because the condition on line 187 was always true1abcde
188 hyperprior.pop('mean') 1abcde
190 # GP factory
191 def makegp(hp, *, i_train, weights, splits, **_): 1abcde
192 kw = dict( 1abcde
193 alpha=hp['alpha'], beta=hp['beta'],
194 maxd=10, reset=[2, 4, 6, 8],
195 )
196 kw.update(kernelkw) 1abcde
197 kernel = _kernels.BART(splits=splits, indices=True, **kw) 1abcde
198 kernel *= (k_sigma_mu / hp['k']) ** 2 1abcde
200 gp = (_GP 1abcde
201 .GP(kernel, checkpos=False, checksym=False, solver='chol')
202 .addx(i_train, 'trainmean')
203 .addcov(jnp.diag(hp['sigma2'] / weights), 'trainnoise')
204 )
205 pieces = {'trainmean': 1, 'trainnoise': 1} 1abcde
206 if 'mean' not in hp: 206 ↛ 209line 206 didn't jump to line 209 because the condition on line 206 was always true1abcde
207 gp = gp.addcov(k_sigma_mu ** 2, 'mean') 1abcde
208 pieces.update({'mean': 1}) 1abcde
209 return gp.addtransf(pieces, 'train') 1abcde
211 # data factory
212 def info(hp, *, mu_mu, **_): 1abcde
213 return {'train': y_train - hp.get('mean', mu_mu)} 1abcde
215 # fit hyperparameters
216 gpkw = dict( 1abcde
217 i_train=i_train,
218 weights=weights,
219 splits=splits,
220 mu_mu=mu_mu,
221 )
222 options = dict( 1abcde
223 verbosity=3,
224 raises=False,
225 minkw=dict(method='l-bfgs-b', options=dict(maxls=4, maxiter=100)),
226 mlkw=dict(epsrel=0),
227 forward=True,
228 gpfactorykw=gpkw,
229 )
230 options.update(fitkw) 1abcde
231 fit = _fit.empbayes_fit(hyperprior, makegp, info, **options) 1abcde
233 # extract hyperparameters from minimization result
234 self.sigma = gvar.sqrt(fit.p['sigma2']) 1abcde
235 self.alpha = fit.p['alpha'] 1abcde
236 self.beta = fit.p['beta'] 1abcde
237 self.meansdev = k_sigma_mu / fit.p['k'] 1abcde
238 self.mean = fit.p.get('mean', mu_mu) 1abcde
240 # set public attributes
241 self.fit = fit 1abcde
243 # set private attributes
244 self._ystd = y_train.std() 1abcde

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', x_test=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        x_test : array or dataframe, optional
            Additional covariates for "test points".
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the
            ``mean`` attribute of the `bart` object. The keys of the GP are
            'Xmean', 'Xnoise', and 'X', where "X" stands for either 'train'
            or 'test', and X = Xmean + Xnoise.
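
        Examples
        --------
        A rough sketch of the intended usage; ``bt`` stands for an already
        fitted `bart` object and ``x_new`` for an illustrative array of test
        covariates:

        >>> gp = bt.gp(x_test=x_new)
        >>> post = gp.predfromdata(bt.data(), 'testmean')  # centered posterior
        >>> post = post + bt.mean                          # add back the mean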
280 """
282 hp = self._gethp(hp, rng)
283 return self._gp(hp, x_test, weights, self.fit.gpfactorykw)

    def _gp(self, hp, x_test, weights, gpfactorykw):

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if x_test is not None:

            # convert covariates to indices
            x_test = self._to_structured(x_test)
            i_test = self._toindices(x_test, gpfactorykw['splits'])
            assert i_test.dtype == gpfactorykw['i_train'].dtype

            # check weights
            if weights is not None:
                weights = jnp.asarray(weights)
                assert weights.shape == i_test.shape
            else:
                weights = jnp.ones(i_test.shape)

            # add test points
            gp = (gp
                .addx(i_test, 'testmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'testnoise')
            )
            pieces = {'testmean': 1, 'testnoise': 1}
            if 'mean' not in hp:
                pieces.update({'mean': 1})
            gp = gp.addtransf(pieces, 'test')

        return gp

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``y_train`` in the format required by the
            `GP.pred` method.
        """

        hp = self._gethp(hp, rng)
        return self.fit.data(hp, **self.fit.gpfactorykw)

    def pred(self, *, hp='map', error=False, format='matrices', x_test=None,
        weights=None, rng=None):
        """
        Predict the outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        error : bool
            If ``False`` (default), make a prediction for the latent mean. If
            ``True``, add the error term.
        format : {'matrices', 'gvar'}
            If 'matrices' (default), return the mean and covariance matrix
            separately. If 'gvar', return an array of gvars.
        x_test : array or dataframe, optional
            Covariates for the locations where the prediction is computed. If
            not specified, predict at the data covariates.
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        If ``format`` is 'matrices' (default):

        mean, cov : arrays
            The mean and covariance matrix of the Normal posterior distribution
            over the regression function at the specified locations.

        If ``format`` is 'gvar':

        out : array of `GVar`
            The same distribution represented as an array of `GVar` objects.
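
        Examples
        --------
        A minimal sketch; ``bt`` is a fitted `bart` object and ``x_new`` holds
        illustrative test covariates:

        >>> mean, cov = bt.pred(x_test=x_new)      # latent mean, 'matrices' format
        >>> yhat = bt.pred(x_test=x_new, error=True, format='gvar')
        ...                                        # include the error term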
377 """
379 # TODO it is a bit confusing that if x_test=None and error=True, the
380 # prediction returns y_train exactly, instead of hypothetical new
381 # observations at the same covariates.
383 hp = self._gethp(hp, rng) 1abcde
384 if x_test is not None: 1abcde
385 x_test = self._to_structured(x_test) 1a
386 mean, cov = self._pred(hp, x_test, weights, self.fit.gpfactorykw, bool(error)) 1abcde
388 if format == 'gvar': 388 ↛ 389line 388 didn't jump to line 389 because the condition on line 388 was never true1abcde
389 return gvar.gvar(mean, cov, fast=True)
390 elif format == 'matrices': 390 ↛ 393line 390 didn't jump to line 393 because the condition on line 390 was always true1abcde
391 return mean, cov 1abcde
392 else:
393 raise KeyError(format)

    @functools.cached_property
    def _pred(self):

        @functools.partial(jax.jit, static_argnums=(4,))
        def _pred(hp, x_test, weights, gpfactorykw, error):
            gp = self._gp(hp, x_test, weights, gpfactorykw)
            data = self.fit.data(hp, **gpfactorykw)
            if x_test is None:
                label = 'train'
            else:
                label = 'test'
            if not error:
                label += 'mean'
            outmean, outcov = gp.predfromdata(data, label, raw=True)
            return outmean + hp.get('mean', gpfactorykw['mu_mu']), outcov

        return _pred

    @classmethod
    def _to_structured(cls, x):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check
        assert x.ndim == 1
        def check_numerical(path, dtype):
            if not numpy.issubdtype(dtype, numpy.number):
                raise TypeError(f'covariate `{path}` is not numerical')
        cls._walk_dtype(x.dtype, check_numerical)

        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        return _array.unstructured_to_structured(ix, names=x.dtype.names)

    def __repr__(self):
        out = f"""BART fit:
alpha = {self.alpha} (0 -> intercept only, 1 -> any)
beta = {self.beta} (0 -> any, ∞ -> no interactions)
mean = {self.mean}
latent sdev = {self.meansdev} (large -> conservative extrapolation)
data total sdev = {self._ystd:.3g}"""

        if self._no_weights:
            out += f"""
error sdev = {self.sigma}"""
        else:
            weights = numpy.array(self.fit.gpfactorykw['weights'])
            avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
            out += f"""
error sdev (avg weighted) = {avgsigma}
error sdev (unweighted) = {self.sigma}"""

        return out