# lsqfitgp/bayestree/_bart.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools

import numpy
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter

# TODO I added a lot of functionality to bcf. The easiest way to port it over
# is to add an option in bcf to drop the second bart model and its associated
# hypers, and then rewrite bart as a thin convenience subclass of bcf.

class bart:

    def __init__(self,
        x_train,
        y_train,
        *,
        weights=None,
        fitkw={},
        kernelkw={},
        marginalize_mean=True,
    ):
        """
        Nonparametric Bayesian regression with a GP version of BART.

        Evaluate a Gaussian process regression with a kernel which accurately
        approximates the infinite-trees limit of BART. The hyperparameters
        are optimized to their marginal MAP.

        Parameters
        ----------
        x_train : (n, p) array or dataframe
            Observed covariates.
        y_train : (n,) array
            Observed outcomes.
        weights : (n,) array
            Weights used to rescale the error variance (as 1 / weight).
        fitkw : dict
            Additional arguments passed to `~lsqfitgp.empbayes_fit`; they
            override the defaults.
        kernelkw : dict
            Additional arguments passed to `~lsqfitgp.BART`; they override
            the defaults.
        marginalize_mean : bool
            If True (default), marginalize the intercept of the model.

        Notes
        -----
        The regression model is:

        .. math::
            y_i &= \\mu + \\lambda f(\\mathbf x_i) + \\varepsilon_i, \\\\
            \\varepsilon_i &\\overset{\\mathrm{i.i.d.}}{\\sim}
                N(0, \\sigma^2 / w_i), \\\\
            \\mu &\\sim N(
                (\\max(\\mathbf y) + \\min(\\mathbf y)) / 2,
                (\\max(\\mathbf y) - \\min(\\mathbf y))^2 / 4
            ), \\\\
            \\log \\sigma^2 &\\sim N(
                \\log(\\overline{w(y - \\bar y)^2}),
                4
            ), \\\\
            \\log \\lambda &\\sim N(
                \\log ((\\max(\\mathbf y) - \\min(\\mathbf y)) / 4),
                4
            ), \\\\
            f &\\sim \\mathrm{GP}(
                0,
                \\mathrm{BART}(\\alpha, \\beta)
            ), \\\\
            \\alpha &\\sim \\mathrm{B}(2, 1), \\\\
            \\beta &\\sim \\mathrm{IG}(1, 1).

        To perform the inference, :math:`(f, \\boldsymbol\\varepsilon, \\mu)`
        is marginalized analytically, and the marginal posterior mode of
        :math:`(\\sigma, \\lambda, \\alpha, \\beta)` is found by numerical
        minimization, after transforming the parameters to express their
        prior as a Gaussian copula. Their marginal posterior covariance
        matrix is estimated with an approximation of the inverse Hessian. See
        `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize
        this procedure.

        The tree splitting grid of the BART kernel is set using quantiles of
        the observed covariates. This corresponds to the settings
        ``usequants=True``, ``numcut=inf`` in the R packages BayesTree and
        BART. Use the ``kernelkw`` parameter to customize the grid.

        Attributes
        ----------
        mean : gvar
            The prior mean :math:`\\mu`.
        sigma : float or gvar
            The error term standard deviation :math:`\\sigma`. If there are
            weights, the sdev of each unit is obtained by dividing ``sigma``
            by the square root of its weight.
        alpha : gvar
            The numerator of the tree spawn probability :math:`\\alpha`
            (named ``base`` in BayesTree and BART).
        beta : gvar
            The depth exponent of the tree spawn probability :math:`\\beta`
            (named ``power`` in BayesTree and BART).
        meansdev : gvar
            The prior standard deviation :math:`\\lambda` of the latent
            regression function.
        fit : empbayes_fit
            The hyperparameters fit object.

        Methods
        -------
        gp :
            Create a GP object.
        data :
            Create the dictionary to be passed to `GP.pred` to represent
            ``y_train``.
        pred :
            Evaluate the regression function at given locations.

        See also
        --------
        lsqfitgp.BART
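
        Examples
        --------
        A minimal usage sketch on synthetic data (it assumes the class is
        exported as `lsqfitgp.bayestree.bart`; expect fitting output, since
        ``verbosity=3`` by default):

        >>> import numpy as np
        >>> import lsqfitgp as lgp
        >>> rng = np.random.default_rng(0)
        >>> X = rng.standard_normal((50, 3))        # 50 units, 3 covariates
        >>> y = X[:, 0] + np.sin(3 * X[:, 1]) + 0.1 * rng.standard_normal(50)
        >>> reg = lgp.bayestree.bart(X, y)          # fits the hyperparameters
        >>> print(reg)                              # summary of the fit
        >>> mean, cov = reg.pred(error=True)        # posterior at the data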

        """

        # convert covariates to StructuredArray
        x_train = self._to_structured(x_train)

        # convert outcomes to 1d array
        if hasattr(y_train, 'to_numpy'):
            y_train = y_train.to_numpy()
            y_train = y_train.squeeze() # for dataframes
        y_train = jnp.asarray(y_train)
        assert y_train.shape == x_train.shape

        # check weights
        self._no_weights = weights is None
        if self._no_weights:
            weights = jnp.ones_like(y_train)
        assert weights.shape == y_train.shape

        # prior mean and variance
        ymin = jnp.min(y_train)
        ymax = jnp.max(y_train)
        mu_mu = (ymax + ymin) / 2
        k_sigma_mu = (ymax - ymin) / 2

        # splitting points and indices
        splits = _kernels.BART.splits_from_coord(x_train)
        i_train = self._toindices(x_train, splits)

        # prior on hyperparams
        sigma2_priormean = numpy.mean((y_train - y_train.mean()) ** 2 * weights)
        hyperprior = copula.makedict({
            'alpha': copula.beta(2, 1), # base of tree gen prob
            'beta': copula.invgamma(1, 1), # exponent of tree gen prob
            'log(k)': gvar.gvar(numpy.log(2), 2), # denominator of prior sdev
            'log(sigma2)': gvar.gvar(numpy.log(sigma2_priormean), 2),
                # i.i.d. error variance, scaled with weights
            'mean': gvar.gvar(mu_mu, k_sigma_mu), # mean of the GP
        })
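        # note: makedict expresses the non-Gaussian priors (beta, invgamma)
        # through the Gaussian copula transformation described in the class
        # docstring, so the minimization runs on Gaussian variables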

        if marginalize_mean:
            hyperprior.pop('mean')

        # GP factory
        def makegp(hp, *, i_train, weights, splits, **_):
            kw = dict(
                alpha=hp['alpha'], beta=hp['beta'],
                maxd=10, reset=[2, 4, 6, 8],
            )
            kw.update(kernelkw)
            kernel = _kernels.BART(splits=splits, indices=True, **kw)
            # the factor (k_sigma_mu / k)² is λ², the prior variance of the
            # latent regression function (cf. `meansdev`)
            kernel *= (k_sigma_mu / hp['k']) ** 2

            gp = (_GP
                .GP(kernel, checkpos=False, checksym=False, solver='chol')
                .addx(i_train, 'trainmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'trainnoise')
            )
            pieces = {'trainmean': 1, 'trainnoise': 1}
            if 'mean' not in hp:
                gp = gp.addcov(k_sigma_mu ** 2, 'mean')
                pieces.update({'mean': 1})
            return gp.addtransf(pieces, 'train')

        # data factory
        def info(hp, *, mu_mu, **_):
            return {'train': y_train - hp.get('mean', mu_mu)}

        # fit hyperparameters
        gpkw = dict(
            i_train=i_train,
            weights=weights,
            splits=splits,
            mu_mu=mu_mu,
        )
        options = dict(
            verbosity=3,
            raises=False,
            minkw=dict(method='l-bfgs-b', options=dict(maxls=4, maxiter=100)),
            mlkw=dict(epsrel=0),
            forward=True,
            gpfactorykw=gpkw,
        )
        options.update(fitkw)
        fit = _fit.empbayes_fit(hyperprior, makegp, info, **options)

        # extract hyperparameters from minimization result
        self.sigma = gvar.sqrt(fit.p['sigma2'])
        self.alpha = fit.p['alpha']
        self.beta = fit.p['beta']
        self.meansdev = k_sigma_mu / fit.p['k']
        self.mean = fit.p.get('mean', mu_mu)

        # set public attributes
        self.fit = fit

        # set private attributes
        self._ystd = y_train.std()

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', x_test=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal
            maximum a posteriori. If ``'sample'``, sample hyperparameters
            from the posterior. If a dict, use the given hyperparameters.
        x_test : array or dataframe, optional
            Additional covariates for "test points".
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the
            ``mean`` attribute of the `bart` object. The keys of the GP are
            'Xmean', 'Xnoise', and 'X', where 'X' is either 'train' or
            'test', and X = Xmean + Xnoise.
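
        Examples
        --------
        A sketch of the intended use together with `data`, continuing the
        example in the class docstring (``reg`` is a fitted `bart`, ``X``
        the training covariates); the result is centered, so add the
        ``mean`` attribute to return to the outcome scale:

        >>> gp = reg.gp(x_test=X)              # reuse the data locations
        >>> m, cov = gp.predfromdata(reg.data(), 'test', raw=True)
        >>> m = m + reg.mean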

        """

        hp = self._gethp(hp, rng)
        return self._gp(hp, x_test, weights, self.fit.gpfactorykw)

    def _gp(self, hp, x_test, weights, gpfactorykw):

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if x_test is not None:

            # convert covariates to indices
            x_test = self._to_structured(x_test)
            i_test = self._toindices(x_test, gpfactorykw['splits'])
            assert i_test.dtype == gpfactorykw['i_train'].dtype

            # check weights
            if weights is not None:
                weights = jnp.asarray(weights)
                assert weights.shape == i_test.shape
            else:
                weights = jnp.ones(i_test.shape)

            # add test points
            gp = (gp
                .addx(i_test, 'testmean')
                .addcov(jnp.diag(hp['sigma2'] / weights), 'testnoise')
            )
            pieces = {'testmean': 1, 'testnoise': 1}
            if 'mean' not in hp:
                pieces.update({'mean': 1})
            gp = gp.addtransf(pieces, 'test')

        return gp

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by
        `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal
            maximum a posteriori. If ``'sample'``, sample hyperparameters
            from the posterior. If a dict, use the given hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``y_train`` in the format required by
            the `GP.pred` method.
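
        Examples
        --------
        A sketch of the round trip through `gp` (``reg`` is a fitted `bart`
        instance):

        >>> gp = reg.gp()
        >>> m, cov = gp.predfromdata(reg.data(), 'trainmean', raw=True)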

        """

        hp = self._gethp(hp, rng)
        return self.fit.data(hp, **self.fit.gpfactorykw)

    def pred(self, *, hp='map', error=False, format='matrices', x_test=None,
        weights=None, rng=None):
        """
        Predict the outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'``, use the marginal
            maximum a posteriori. If ``'sample'``, sample hyperparameters
            from the posterior. If a dict, use the given hyperparameters.
        error : bool
            If ``False`` (default), make a prediction for the latent mean.
            If ``True``, add the error term.
        format : {'matrices', 'gvar'}
            If 'matrices' (default), return the mean and covariance matrix
            separately. If 'gvar', return an array of gvars.
        x_test : array or dataframe, optional
            Covariates for the locations where the prediction is computed.
            If not specified, predict at the data covariates.
        weights : array, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        If ``format`` is 'matrices' (default):

        mean, cov : arrays
            The mean and covariance matrix of the Normal posterior
            distribution over the regression function at the specified
            locations.

        If ``format`` is 'gvar':

        out : array of `GVar`
            The same distribution represented as an array of `GVar` objects.
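
        Examples
        --------
        Typical invocations, continuing the example in the class docstring:

        >>> m, cov = reg.pred(x_test=X)              # latent mean at X
        >>> yhat = reg.pred(x_test=X, error=True, format='gvar')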

        """

        # TODO it is a bit confusing that if x_test=None and error=True, the
        # prediction returns y_train exactly, instead of hypothetical new
        # observations at the same covariates.

        hp = self._gethp(hp, rng)
        if x_test is not None:
            x_test = self._to_structured(x_test)
        mean, cov = self._pred(hp, x_test, weights, self.fit.gpfactorykw, bool(error))

        if format == 'gvar':
            return gvar.gvar(mean, cov, fast=True)
        elif format == 'matrices':
            return mean, cov
        else:
            raise KeyError(format)

    @functools.cached_property
    def _pred(self):

        # argument 4 (`error`) is a plain Python bool used in control flow
        # below, so it is marked static for jax.jit
        @functools.partial(jax.jit, static_argnums=(4,))
        def _pred(hp, x_test, weights, gpfactorykw, error):
            gp = self._gp(hp, x_test, weights, gpfactorykw)
            data = self.fit.data(hp, **gpfactorykw)
            if x_test is None:
                label = 'train'
            else:
                label = 'test'
            if not error:
                label += 'mean'
            outmean, outcov = gp.predfromdata(data, label, raw=True)
            return outmean + hp.get('mean', gpfactorykw['mu_mu']), outcov

        return _pred

    @classmethod
    def _to_structured(cls, x):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check
        assert x.ndim == 1
        def check_numerical(path, dtype):
            if not numpy.issubdtype(dtype, numpy.number):
                raise TypeError(f'covariate `{path}` is not numerical')
        cls._walk_dtype(x.dtype, check_numerical)

        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        return _array.unstructured_to_structured(ix, names=x.dtype.names)

    def __repr__(self):
        out = f"""BART fit:
alpha = {self.alpha} (0 -> intercept only, 1 -> any)
beta = {self.beta} (0 -> any, ∞ -> no interactions)
mean = {self.mean}
latent sdev = {self.meansdev} (large -> conservative extrapolation)
data total sdev = {self._ystd:.3g}"""

        if self._no_weights:
            out += f"""
error sdev = {self.sigma}"""
        else:
            weights = numpy.array(self.fit.gpfactorykw['weights'])
            avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
            out += f"""
error sdev (avg weighted) = {avgsigma}
error sdev (unweighted) = {self.sigma}"""

        return out