# lsqfitgp/bayestree/_bart.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools

import numpy
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter

# TODO I added a lot of functionality to bcf. The easiest way to port it over
# is adding the option in bcf to drop the second bart model and its associated
# hypers, and then write bart as a simple convenience wrapper-subclass over
# bcf. (Also the option include_pi='none'.)

class bart:

    def __init__(self,
        x_train,
        y_train,
        *,
        weights=None,
        fitkw={},
        kernelkw={},
        marginalize_mean=True,
    ):
50 """ 

51 Nonparametric Bayesian regression with a GP version of BART. 

52 

53 Evaluate a Gaussian process regression with a kernel which accurately 

54 approximates the infinite trees limit of BART. The hyperparameters are 

55 optimized to their marginal MAP. 

56 

57 Parameters 

58 ---------- 

59 x_train : (n, p) array or dataframe 

60 Observed covariates. 

61 y_train : (n,) array 

62 Observed outcomes. 

63 weights : (n,) array 

64 Weights used to rescale the error variance (as 1 / weight). 

65 fitkw : dict 

66 Additional arguments passed to `~lsqfitgp.empbayes_fit`, overrides 

67 the defaults. 

68 kernelkw : dict 

69 Additional arguments passed to `~lsqfitgp.BART`, overrides the 

70 defaults. 

71 marginalize_mean : bool 

72 If True (default), marginalize the intercept of the model. 

73  

74 Notes 

75 ----- 

76 The regression model is: 

77 

78 .. math:: 

79 y_i &= \\mu + \\lambda f(\\mathbf x_i) + \\varepsilon_i, \\\\ 

80 \\varepsilon_i &\\overset{\\mathrm{i.i.d.}}{\\sim} 

81 N(0, \\sigma^2 / w_i), \\\\ 

82 \\mu &\\sim N( 

83 (\\max(\\mathbf y) + \\min(\\mathbf y)) / 2, 

84 (\\max(\\mathbf y) - \\min(\\mathbf y))^2 / 4 

85 ), \\\\ 

86 \\log \\sigma^2 &\\sim N( 

87 \\log(\\overline{w(y - \\bar y)^2}), 

88 4 

89 ), \\\\ 

90 \\log \\lambda &\\sim N( 

91 \\log ((\\max(\\mathbf y) - \\min(\\mathbf y)) / 4), 

92 4 

93 ), \\\\ 

94 f &\\sim \\mathrm{GP}( 

95 0, 

96 \\mathrm{BART}(\\alpha,\\beta) 

97 ), \\\\ 

98 \\alpha &\\sim \\mathrm{B}(2, 1), \\\\ 

99 \\beta &\\sim \\mathrm{IG}(1, 1). 

100 

101 To make the inference, :math:`(f, \\boldsymbol\\varepsilon, \\mu)` are 

102 marginalized analytically, and the marginal posterior mode of 

103 :math:`(\\sigma, \\lambda, \\alpha, \\beta)` is found by numerical 

104 minimization, after transforming them to express their prior as a 

105 Gaussian copula. Their marginal posterior covariance matrix is estimated 

106 with an approximation of the hessian inverse. See 

107 `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize 

108 this procedure. 

109 

110 The tree splitting grid of the BART kernel is set using quantiles of the 

111 observed covariates. This corresponds to settings ``usequants=True``, 

112 ``numcut=inf`` in the R packages BayesTree and BART. Use the 

113 ``kernelkw`` parameter to customize the grid. 

114 

115 Attributes 

116 ---------- 

117 mean : gvar 

118 The prior mean :math:`\\mu`. 

119 sigma : float or gvar 

120 The error term standard deviation :math:`\\sigma`. If there are 

121 weights, the sdev for each unit is obtained dividing ``sigma`` by 

122 sqrt(weight). 

123 alpha : gvar 

124 The numerator of the tree spawn probability :math:`\\alpha` (named 

125 ``base`` in BayesTree and BART). 

126 beta : gvar 

127 The depth exponent of the tree spawn probability :math:`\\beta` 

128 (named ``power`` in BayesTree and BART). 

129 meansdev : gvar 

130 The prior standard deviation :math:`\\lambda` of the latent 

131 regression function. 

132 fit : empbayes_fit 

133 The hyperparameters fit object. 

134 

135 Methods 

136 ------- 

137 gp : 

138 Create a GP object. 

139 data : 

140 Creates the dictionary to be passed to `GP.pred` to represent 

141 ``y_train``. 

142 pred : 

143 Evaluate the regression function at given locations. 

144 

145 See also 

146 -------- 

147 lsqfitgp.BART 

148  
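
        Examples
        --------
        A minimal sketch on synthetic data, assuming the class is importable
        as ``lsqfitgp.bayestree.bart``; the dataset and the ``verbosity``
        override are illustrative, not part of the API::

            import numpy as np
            import lsqfitgp as lgp

            rng = np.random.default_rng(0)
            X = rng.standard_normal((100, 3))             # (n, p) covariates
            y = X[:, 0] + 0.1 * rng.standard_normal(100)  # (n,) outcomes

            model = lgp.bayestree.bart(X, y, fitkw=dict(verbosity=0))
            print(model)                           # fitted hyperparameters
            mean, cov = model.pred(x_test=X[:10])  # posterior on the latent mean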

149 """ 

150 

151 # convert covariates to StructuredArray 

152 x_train = self._to_structured(x_train) 1abcde

153 

154 # convert outcomes to 1d array 

155 if hasattr(y_train, 'to_numpy'): 1abcde

156 y_train = y_train.to_numpy() 1a

157 y_train = y_train.squeeze() # for dataframes 1a

158 y_train = jnp.asarray(y_train) 1abcde

159 assert y_train.shape == x_train.shape 1abcde

160 

161 # check weights 

162 self._no_weights = weights is None 1abcde

163 if self._no_weights: 163 ↛ 165line 163 didn't jump to line 165 because the condition on line 163 was always true1abcde

164 weights = jnp.ones_like(y_train) 1abcde

165 assert weights.shape == y_train.shape 1abcde

166 

167 # prior mean and variance 

168 ymin = jnp.min(y_train) 1abcde

169 ymax = jnp.max(y_train) 1abcde

170 mu_mu = (ymax + ymin) / 2 1abcde

171 k_sigma_mu = (ymax - ymin) / 2 1abcde

172 

173 # splitting points and indices 

174 splits = _kernels.BART.splits_from_coord(x_train) 1abcde

175 i_train = self._toindices(x_train, splits) 1abcde

176 

177 # prior on hyperparams 

178 sigma2_priormean = numpy.mean((y_train - y_train.mean()) ** 2 * weights) 1abcde

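        # keys like 'log(sigma2)' are priors on transformed parameters:
        # makedict assembles a Gaussian prior in the transformed space, i.e.,
        # the Gaussian-copula parametrization that the minimizer works on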
        hyperprior = copula.makedict({
            'alpha': copula.beta(2, 1), # base of tree gen prob
            'beta': copula.invgamma(1, 1), # exponent of tree gen prob
            'log(k)': gvar.gvar(numpy.log(2), 2), # denominator of prior sdev
            'log(sigma2)': gvar.gvar(numpy.log(sigma2_priormean), 2),
                # i.i.d. error variance, scaled with weights
            'mean': gvar.gvar(mu_mu, k_sigma_mu), # mean of the GP
        })
        if marginalize_mean:
            hyperprior.pop('mean')

        # GP factory
        def makegp(hp, *, i_train, weights, splits, **_):
            kw = dict(
                alpha=hp['alpha'], beta=hp['beta'],
                maxd=10, reset=[2, 4, 6, 8],
            )
            kw.update(kernelkw)
            kernel = _kernels.BART(splits=splits, indices=True, **kw)
            kernel *= (k_sigma_mu / hp['k']) ** 2

            gp = (_GP
                .GP(kernel, checkpos=False, checksym=False, solver='chol')
                .addx(i_train, 'trainmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'trainnoise')
            )
            pieces = {'trainmean': 1, 'trainnoise': 1}

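            # when the intercept is marginalized analytically there is no
            # 'mean' hyperparameter: add its prior variance k_sigma_mu ** 2
            # to the GP as a constant covariance component instead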
            if 'mean' not in hp:
                gp = gp.addcov(k_sigma_mu ** 2, 'mean')
                pieces.update({'mean': 1})
            return gp.addtransf(pieces, 'train')

        # data factory
        def info(hp, *, mu_mu, **_):
            return {'train': y_train - hp.get('mean', mu_mu)}

        # fit hyperparameters
        gpkw = dict(
            i_train=i_train,
            weights=weights,
            splits=splits,
            mu_mu=mu_mu,
        )
        options = dict(
            verbosity=3,
            raises=False,
            minkw=dict(method='l-bfgs-b', options=dict(maxls=4, maxiter=100)),
            mlkw=dict(epsrel=0),
            forward=True,
            gpfactorykw=gpkw,
        )
        options.update(fitkw)
        fit = _fit.empbayes_fit(hyperprior, makegp, info, **options)

        # extract hyperparameters from minimization result
        self.sigma = gvar.sqrt(fit.p['sigma2'])
        self.alpha = fit.p['alpha']
        self.beta = fit.p['beta']
        self.meansdev = k_sigma_mu / fit.p['k']
        self.mean = fit.p.get('mean', mu_mu)

        # set public attributes
        self.fit = fit

        # set private attributes
        self._ystd = y_train.std()

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', x_test=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        x_test : array or dataframe, optional
            Additional covariates for "test points".
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the
            ``mean`` attribute of the `bart` object. The keys of the GP are
            'Xmean', 'Xnoise', and 'X', where the "X" stands either for
            'train' or 'test', and X = Xmean + Xnoise.
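
        Examples
        --------
        A sketch of manual prediction with the returned object (``model`` is
        a fitted `bart` instance, ``x_new`` a hypothetical test array); this
        roughly mirrors what `pred` does internally::

            gp = model.gp(x_test=x_new)
            u = model.mean + gp.predfromdata(model.data(), 'testmean')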

280 """ 

281 

282 hp = self._gethp(hp, rng) 

283 return self._gp(hp, x_test, weights, self.fit.gpfactorykw) 

284 

    def _gp(self, hp, x_test, weights, gpfactorykw):

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if x_test is not None:

            # convert covariates to indices
            x_test = self._to_structured(x_test)
            i_test = self._toindices(x_test, gpfactorykw['splits'])
            assert i_test.dtype == gpfactorykw['i_train'].dtype

            # check weights
            if weights is not None:
                weights = jnp.asarray(weights)
                assert weights.shape == i_test.shape
            else:
                weights = jnp.ones(i_test.shape)

            # add test points
            gp = (gp
                .addx(i_test, 'testmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'testnoise')
            )
            pieces = {'testmean': 1, 'testnoise': 1}
            if 'mean' not in hp:
                pieces.update({'mean': 1})
            gp = gp.addtransf(pieces, 'test')

        return gp

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``y_train`` in the format required by
            the `GP.pred` method.
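
        Examples
        --------
        Sketch of the intended pairing with `gp` (``model`` is a fitted
        `bart` instance); note the data are residuals around the fitted
        mean, so the resulting posterior is centered::

            gp = model.gp()
            u = gp.predfromdata(model.data(), 'trainmean')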

335 """ 

336 

337 hp = self._gethp(hp, rng) 

338 return self.fit.data(hp, **self.fit.gpfactorykw) 

339 

    def pred(self, *, hp='map', error=False, format='matrices', x_test=None,
        weights=None, rng=None):
        """
        Predict the outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal maximum
            a posteriori. If ``'sample'``, sample hyperparameters from the
            posterior. If a dict, use the given hyperparameters.
        error : bool
            If ``False`` (default), make a prediction for the latent mean. If
            ``True``, add the error term.
        format : {'matrices', 'gvar'}
            If 'matrices' (default), return the mean and covariance matrix
            separately. If 'gvar', return an array of gvars.
        x_test : array or dataframe, optional
            Covariates for the locations where the prediction is computed. If
            not specified, predict at the data covariates.
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        If ``format`` is 'matrices' (default):

        mean, cov : arrays
            The mean and covariance matrix of the Normal posterior
            distribution over the regression function at the specified
            locations.

        If ``format`` is 'gvar':

        out : array of `GVar`
            The same distribution represented as an array of `GVar` objects.
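
        Examples
        --------
        Sketch (``model`` is a fitted `bart` instance, ``x_new`` a
        hypothetical array with the same columns as ``x_train``)::

            mean, cov = model.pred(x_test=x_new)     # latent regression mean
            y_new = model.pred(x_test=x_new, error=True, format='gvar')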

377 """ 

378 

379 # TODO it is a bit confusing that if x_test=None and error=True, the 

380 # prediction returns y_train exactly, instead of hypothetical new 

381 # observations at the same covariates. 

382 

383 hp = self._gethp(hp, rng) 1abcde

384 if x_test is not None: 1abcde

385 x_test = self._to_structured(x_test) 1a

386 mean, cov = self._pred(hp, x_test, weights, self.fit.gpfactorykw, bool(error)) 1abcde

387 

388 if format == 'gvar': 388 ↛ 389line 388 didn't jump to line 389 because the condition on line 388 was never true1abcde

389 return gvar.gvar(mean, cov, fast=True) 

390 elif format == 'matrices': 390 ↛ 393line 390 didn't jump to line 393 because the condition on line 390 was always true1abcde

391 return mean, cov 1abcde

392 else: 

393 raise KeyError(format) 

394 

395 @functools.cached_property 1fabcde

396 def _pred(self): 1fabcde

397 

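        # error is marked static under jit because it changes which GP key
        # is predicted ('test'/'train' vs 'testmean'/'trainmean'), i.e., the
        # traced computation itself, not just array values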
        @functools.partial(jax.jit, static_argnums=(4,))
        def _pred(hp, x_test, weights, gpfactorykw, error):
            gp = self._gp(hp, x_test, weights, gpfactorykw)
            data = self.fit.data(hp, **gpfactorykw)
            if x_test is None:
                label = 'train'
            else:
                label = 'test'
            if not error:
                label += 'mean'
            outmean, outcov = gp.predfromdata(data, label, raw=True)
            return outmean + hp.get('mean', gpfactorykw['mu_mu']), outcov

        return _pred

    @classmethod
    def _to_structured(cls, x):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check
        assert x.ndim == 1
        def check_numerical(path, dtype):
            if not numpy.issubdtype(dtype, numpy.number):
                raise TypeError(f'covariate `{path}` is not numerical')
        cls._walk_dtype(x.dtype, check_numerical)

        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        return _array.unstructured_to_structured(ix, names=x.dtype.names)

    def __repr__(self):
        out = f"""BART fit:
alpha = {self.alpha} (0 -> intercept only, 1 -> any)
beta = {self.beta} (0 -> any, ∞ -> no interactions)
mean = {self.mean}
latent sdev = {self.meansdev} (large -> conservative extrapolation)
data total sdev = {self._ystd:.3g}"""

        if self._no_weights:
            out += f"""
error sdev = {self.sigma}"""
        else:
            weights = numpy.array(self.fit.gpfactorykw['weights'])
            avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
            out += f"""
error sdev (avg weighted) = {avgsigma}
error sdev (unweighted) = {self.sigma}"""

        return out