Coverage for src/lsqfitgp/bayestree/_bart.py: 84%
143 statements
coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

# lsqfitgp/bayestree/_bart.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools

import numpy
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter

# TODO I added a lot of functionality to bcf. The easiest way to port it over is
# adding the option in bcf to drop the second bart model and its associated
# hypers, and then write bart as a simple convenience wrapper-subclass over bcf.

class bart:

    def __init__(self,
        x_train,
        y_train,
        *,
        weights=None,
        fitkw={},
        kernelkw={},
        marginalize_mean=True,
    ):
        """
        Nonparametric Bayesian regression with a GP version of BART.

        Evaluate a Gaussian process regression with a kernel that accurately
        approximates the infinite trees limit of BART. The hyperparameters are
        optimized to their marginal MAP.

        Parameters
        ----------
        x_train : (n, p) array or dataframe
            Observed covariates.
        y_train : (n,) array
            Observed outcomes.
        weights : (n,) array, optional
            Weights used to rescale the error variance (as 1 / weight).
        fitkw : dict
            Additional arguments passed to `~lsqfitgp.empbayes_fit`,
            overriding the defaults.
        kernelkw : dict
            Additional arguments passed to `~lsqfitgp.BART`, overriding the
            defaults.
        marginalize_mean : bool
            If True (default), marginalize the intercept of the model.

        Notes
        -----
        The regression model is:

        .. math::
            y_i &= \\mu + \\lambda f(\\mathbf x_i) + \\varepsilon_i, \\\\
            \\varepsilon_i &\\overset{\\mathrm{i.i.d.}}{\\sim}
                N(0, \\sigma^2 / w_i), \\\\
            \\mu &\\sim N(
                (\\max(\\mathbf y) + \\min(\\mathbf y)) / 2,
                (\\max(\\mathbf y) - \\min(\\mathbf y))^2 / 4
            ), \\\\
            \\log \\sigma^2 &\\sim N(
                \\log(\\overline{w(y - \\bar y)^2}),
                4
            ), \\\\
            \\log \\lambda &\\sim N(
                \\log ((\\max(\\mathbf y) - \\min(\\mathbf y)) / 4),
                4
            ), \\\\
            f &\\sim \\mathrm{GP}(
                0,
                \\mathrm{BART}(\\alpha, \\beta)
            ), \\\\
            \\alpha &\\sim \\mathrm{B}(2, 1), \\\\
            \\beta &\\sim \\mathrm{IG}(1, 1).

        To make the inference, :math:`(f, \\boldsymbol\\varepsilon, \\mu)` are
        marginalized analytically, and the marginal posterior mode of
        :math:`(\\sigma, \\lambda, \\alpha, \\beta)` is found by numerical
        minimization, after transforming them to express their prior as a
        Gaussian copula. Their marginal posterior covariance matrix is
        estimated with an approximation of the inverse Hessian. See
        `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize
        this procedure.

        The tree splitting grid of the BART kernel is set using quantiles of
        the observed covariates. This corresponds to the settings
        ``usequants=True``, ``numcut=inf`` in the R packages BayesTree and
        BART. Use the ``kernelkw`` parameter to customize the grid.

        Attributes
        ----------
        mean : gvar
            The prior mean :math:`\\mu`.
        sigma : float or gvar
            The error term standard deviation :math:`\\sigma`. If there are
            weights, the sdev for each unit is obtained by dividing ``sigma``
            by the square root of its weight.
        alpha : gvar
            The numerator of the tree spawn probability :math:`\\alpha` (named
            ``base`` in BayesTree and BART).
        beta : gvar
            The depth exponent of the tree spawn probability :math:`\\beta`
            (named ``power`` in BayesTree and BART).
        meansdev : gvar
            The prior standard deviation :math:`\\lambda` of the latent
            regression function.
        fit : empbayes_fit
            The hyperparameters fit object.

        Methods
        -------
        gp :
            Create a GP object.
        data :
            Create the dictionary to be passed to `GP.pred` to represent
            ``y_train``.
        pred :
            Evaluate the regression function at given locations.

        See also
        --------
        lsqfitgp.BART
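
        Examples
        --------
        A hypothetical usage sketch, not a tested doctest: it assumes the
        class is importable as ``lsqfitgp.bayestree.bart`` and that ``X`` is
        an (n, p) array of numerical covariates with ``y`` the corresponding
        (n,) array of outcomes. ::

            import lsqfitgp as lgp

            reg = lgp.bayestree.bart(X, y, fitkw=dict(verbosity=0))
            print(reg)  # summary of the fitted hyperparameters
            mean, cov = reg.pred(x_test=X, error=False)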
        """

        # convert covariates to StructuredArray
        x_train = self._to_structured(x_train)

        # convert outcomes to 1d array
        if hasattr(y_train, 'to_numpy'):
            y_train = y_train.to_numpy()
            y_train = y_train.squeeze() # for dataframes
        y_train = jnp.asarray(y_train)
        assert y_train.shape == x_train.shape

        # check weights
        self._no_weights = weights is None
        if self._no_weights:
            weights = jnp.ones_like(y_train)
        assert weights.shape == y_train.shape

        # prior mean and variance
        ymin = jnp.min(y_train)
        ymax = jnp.max(y_train)
        mu_mu = (ymax + ymin) / 2
        k_sigma_mu = (ymax - ymin) / 2

        # splitting points and indices
        splits = _kernels.BART.splits_from_coord(x_train)
        i_train = self._toindices(x_train, splits)

        # prior on hyperparams
        sigma2_priormean = numpy.mean((y_train - y_train.mean()) ** 2 * weights)
        hyperprior = copula.makedict({
            'alpha': copula.beta(2, 1), # base of tree gen prob
            'beta': copula.invgamma(1, 1), # exponent of tree gen prob
            'log(k)': gvar.gvar(numpy.log(2), 2), # denominator of prior sdev
            'log(sigma2)': gvar.gvar(numpy.log(sigma2_priormean), 2),
                # i.i.d. error variance, scaled with weights
            'mean': gvar.gvar(mu_mu, k_sigma_mu), # mean of the GP
        })
        if marginalize_mean:
            hyperprior.pop('mean')

        # GP factory
        def makegp(hp, *, i_train, weights, splits, **_):
            kw = dict(
                alpha=hp['alpha'], beta=hp['beta'],
                maxd=10, reset=[2, 4, 6, 8],
            )
            kw.update(kernelkw)
            kernel = _kernels.BART(splits=splits, indices=True, **kw)
            kernel *= (k_sigma_mu / hp['k']) ** 2

            gp = (_GP
                .GP(kernel, checkpos=False, checksym=False, solver='chol')
                .addx(i_train, 'trainmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'trainnoise')
            )
            pieces = {'trainmean': 1, 'trainnoise': 1}
            if 'mean' not in hp:
                gp = gp.addcov(k_sigma_mu ** 2, 'mean')
                pieces.update({'mean': 1})
            return gp.addtransf(pieces, 'train')

        # data factory
        def info(hp, *, mu_mu, **_):
            return {'train': y_train - hp.get('mean', mu_mu)}

        # fit hyperparameters
        gpkw = dict(
            i_train=i_train,
            weights=weights,
            splits=splits,
            mu_mu=mu_mu,
        )
        options = dict(
            verbosity=3,
            raises=False,
            minkw=dict(method='l-bfgs-b', options=dict(maxls=4, maxiter=100)),
            mlkw=dict(epsrel=0),
            forward=True,
            gpfactorykw=gpkw,
        )
        options.update(fitkw)
        fit = _fit.empbayes_fit(hyperprior, makegp, info, **options)

        # extract hyperparameters from minimization result
        self.sigma = gvar.sqrt(fit.p['sigma2'])
        self.alpha = fit.p['alpha']
        self.beta = fit.p['beta']
        self.meansdev = k_sigma_mu / fit.p['k']
        self.mean = fit.p.get('mean', mu_mu)

        # set public attributes
        self.fit = fit

        # set private attributes
        self._ystd = y_train.std()

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', x_test=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        x_test : array or dataframe, optional
            Additional covariates for "test points".
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the
            ``mean`` attribute of the `bart` object. The keys of the GP are
            'Xmean', 'Xnoise', and 'X', where "X" stands for either 'train'
            or 'test', and X = Xmean + Xnoise.
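
        Examples
        --------
        A hypothetical sketch, not a tested doctest, of combining this method
        with `data`; ``reg`` is an assumed fitted `bart` object and ``X_new``
        an assumed array of test covariates. It mirrors the ``predfromdata``
        pattern used internally by `pred`. ::

            gp = reg.gp(x_test=X_new)
            mean, cov = gp.predfromdata(reg.data(), 'testmean', raw=True)
            # 'testmean' excludes the error term and the intercept; add
            # reg.mean back to recover the posterior mean of the outcome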
        """

        hp = self._gethp(hp, rng)
        return self._gp(hp, x_test, weights, self.fit.gpfactorykw)

    def _gp(self, hp, x_test, weights, gpfactorykw):

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if x_test is not None:

            # convert covariates to indices
            x_test = self._to_structured(x_test)
            i_test = self._toindices(x_test, gpfactorykw['splits'])
            assert i_test.dtype == gpfactorykw['i_train'].dtype

            # check weights
            if weights is not None:
                weights = jnp.asarray(weights)
                assert weights.shape == i_test.shape
            else:
                weights = jnp.ones(i_test.shape)

            # add test points
            gp = (gp
                .addx(i_test, 'testmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'testnoise')
            )
            pieces = {'testmean': 1, 'testnoise': 1}
            if 'mean' not in hp:
                pieces.update({'mean': 1})
            gp = gp.addtransf(pieces, 'test')

        return gp

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``y_train`` in the format required by
            the `GP.pred` method.
        """
        hp = self._gethp(hp, rng)
        return self.fit.data(hp, **self.fit.gpfactorykw)

    def pred(self, *, hp='map', error=False, format='matrices', x_test=None,
        weights=None, rng=None):
        """
        Predict the outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        error : bool
            If ``False`` (default), make a prediction for the latent mean. If
            ``True``, add the error term.
        format : {'matrices', 'gvar'}
            If 'matrices' (default), return the mean and covariance matrix
            separately. If 'gvar', return an array of gvars.
        x_test : array or dataframe, optional
            Covariates for the locations where the prediction is computed. If
            not specified, predict at the data covariates.
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        If ``format`` is 'matrices' (default):

        mean, cov : arrays
            The mean and covariance matrix of the Normal posterior
            distribution over the regression function at the specified
            locations.

        If ``format`` is 'gvar':

        out : array of `GVar`
            The same distribution represented as an array of `GVar` objects.
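
        Examples
        --------
        A hypothetical sketch, not a tested doctest; ``reg`` is an assumed
        fitted `bart` object and ``X_new`` an assumed array of test
        covariates. ::

            # posterior on the latent mean at new locations, as gvars
            latent = reg.pred(x_test=X_new, format='gvar')

            # posterior predictive with the error term, as mean and covariance
            mean, cov = reg.pred(x_test=X_new, error=True)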
        """

        # TODO it is a bit confusing that if x_test=None and error=True, the
        # prediction returns y_train exactly, instead of hypothetical new
        # observations at the same covariates.

        hp = self._gethp(hp, rng)
        if x_test is not None:
            x_test = self._to_structured(x_test)
        mean, cov = self._pred(hp, x_test, weights, self.fit.gpfactorykw, bool(error))

        if format == 'gvar':
            return gvar.gvar(mean, cov, fast=True)
        elif format == 'matrices':
            return mean, cov
        else:
            raise KeyError(format)

    @functools.cached_property
    def _pred(self):

        @functools.partial(jax.jit, static_argnums=(4,))
        def _pred(hp, x_test, weights, gpfactorykw, error):
            gp = self._gp(hp, x_test, weights, gpfactorykw)
            data = self.fit.data(hp, **gpfactorykw)
            if x_test is None:
                label = 'train'
            else:
                label = 'test'
            if not error:
                label += 'mean'
            outmean, outcov = gp.predfromdata(data, label, raw=True)
            return outmean + hp.get('mean', gpfactorykw['mu_mu']), outcov

        return _pred

    @classmethod
    def _to_structured(cls, x):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check
        assert x.ndim == 1
        def check_numerical(path, dtype):
            if not numpy.issubdtype(dtype, numpy.number):
                raise TypeError(f'covariate `{path}` is not numerical')
        cls._walk_dtype(x.dtype, check_numerical)

        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        return _array.unstructured_to_structured(ix, names=x.dtype.names)

    def __repr__(self):
        out = f"""BART fit:
alpha = {self.alpha} (0 -> intercept only, 1 -> any)
beta = {self.beta} (0 -> any, ∞ -> no interactions)
mean = {self.mean}
latent sdev = {self.meansdev} (large -> conservative extrapolation)
data total sdev = {self._ystd:.3g}"""

        if self._no_weights:
            out += f"""
error sdev = {self.sigma}"""
        else:
            weights = numpy.array(self.fit.gpfactorykw['weights'])
            avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
            out += f"""
error sdev (avg weighted) = {avgsigma}
error sdev (unweighted) = {self.sigma}"""

        return out