Coverage for src/bartz/BART.py: 76%

205 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 18:54 +0000

1# bartz/src/bartz/BART.py 

2# 

3# Copyright (c) 2024, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25import functools 1a

26 

27import jax 1a

28import jax.numpy as jnp 1a

29 

30from . import jaxext 1a

31from . import grove 1a

32from . import mcmcstep 1a

33from . import mcmcloop 1a

34from . import prepcovars 1a

35 

class gbart:
    """
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train : array (p, n) or DataFrame
        The training predictors. An array has one predictor per row, while a
        DataFrame has one predictor per column.
    y_train : array (n,) or Series
        The training responses.
    x_test : array (p, m) or DataFrame, optional
        The test predictors, in the same format as `x_train`.
    usequants : bool, default False
        Whether to use predictors quantiles instead of a uniform grid to bin
        predictors.
    sigest : float, optional
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression.
        If `y_train` has fewer than two elements, it is set to 1. If n <= p,
        it is set to the standard deviation of `y_train`. Ignored if `lamda`
        is specified.
    sigdf : int, default 3
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant : float, default 0.9
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k : float, default 2
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power : float, default 2
    base : float, default 0.95
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    maxdepth : int, default 6
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    lamda : float, optional
        The scale of the prior on the noise variance. If ``lamda==1``, the
        prior is an inverse chi-squared scaled to have harmonic mean 1. If
        not specified, it is set based on `sigest` and `sigquant`.
    offset : float, optional
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train`. If `y_train` is empty, it is set to 0.
    ntree : int, default 200
        The number of trees used to represent the latent mean function.
    numcut : int, default 255
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type.
    ndpost : int, default 1000
        The number of MCMC samples to save, after burn-in.
    nskip : int, default 100
        The number of initial MCMC samples to discard as burn-in.
    keepevery : int, default 1
        The thinning factor for the MCMC samples, after burn-in.
    printevery : int, default 100
        The number of iterations (including skipped ones) between each log.
    seed : int or jax random key, default 0
        The seed for the random number generator.
    initkw : dict, optional
        Additional keyword arguments passed to `mcmcstep.init`. They override
        the arguments computed from the other parameters (for example
        ``min_points_per_leaf``, which is otherwise set to 5).

    Attributes
    ----------
    yhat_train : array (ndpost, n)
        The conditional posterior mean at `x_train` for each MCMC iteration.
        Computed on first access.
    yhat_train_mean : array (n,)
        The marginal posterior mean at `x_train`. Computed on first access.
    yhat_test : array (ndpost, m)
        The conditional posterior mean at `x_test` for each MCMC iteration.
        Only set if `x_test` is given.
    yhat_test_mean : array (m,)
        The marginal posterior mean at `x_test`. Only set if `x_test` is given.
    sigma : array (ndpost,)
        The standard deviation of the error, for each saved MCMC iteration.
    first_sigma : array (nskip,)
        The standard deviation of the error in the burn-in phase.
    offset : float
        The prior mean of the latent mean function.
    scale : float
        The prior standard deviation of the latent mean function.
    lamda : float
        The prior harmonic mean of the error variance.
    sigest : float or None
        The estimated standard deviation of the error used to set `lamda`.
    ntree : int
        The number of trees.
    maxdepth : int
        The maximum depth of the trees.

    Methods
    -------
    predict

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - The error variance parameter is called `lamda` instead of `lambda`
      (``lambda`` is a reserved word in Python).
    - `rm_const` is always `False`.
    - The default `numcut` is 255 instead of 100.
    - A lot of functionality is missing (variable selection, discrete response).
    - There are some additional attributes, and some missing.

    The linear regression used to set `sigest` adds an intercept.
    """

    def __init__(
        self,
        x_train,
        y_train,
        *,
        x_test=None,
        usequants=False,
        sigest=None,
        sigdf=3,
        sigquant=0.9,
        k=2,
        power=2,
        base=0.95,
        maxdepth=6,
        lamda=None,
        offset=None,
        ntree=200,
        numcut=255,
        ndpost=1000,
        nskip=100,
        keepevery=1,
        printevery=100,
        seed=0,
        initkw=None,
    ):
        # avoid a shared mutable default; an absent initkw means "no overrides"
        if initkw is None:
            initkw = {}

        x_train, x_train_fmt = self._process_predictor_input(x_train)

        y_train, y_train_fmt = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)

        # prior hyperparameters, expressed in the original units of y_train
        offset = self._process_offset_settings(y_train, offset)
        scale = self._process_scale_settings(y_train, k)
        lamda, sigest = self._process_noise_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda, offset
        )

        splits, max_split = self._determine_splits(x_train, usequants, numcut)
        x_train = self._bin_predictors(x_train, splits)

        # the MCMC works on the standardized response (y - offset) / scale, so
        # the noise variance scale must be rescaled by scale**2 as well
        y_train = self._transform_input(y_train, offset, scale)
        lamda_scaled = lamda / (scale * scale)

        mcmc_state = self._setup_mcmc(
            x_train, y_train, max_split, lamda_scaled, sigdf, power, base,
            maxdepth, ntree, initkw,
        )
        final_state, burnin_trace, main_trace = self._run_mcmc(
            mcmc_state, ndpost, nskip, keepevery, printevery, seed
        )

        sigma = self._extract_sigma(main_trace, scale)
        first_sigma = self._extract_sigma(burnin_trace, scale)

        self.offset = offset
        self.scale = scale
        self.lamda = lamda
        self.sigest = sigest
        self.ntree = ntree
        self.maxdepth = maxdepth
        self.sigma = sigma
        self.first_sigma = first_sigma

        # kept for predict() and the lazily computed yhat_train
        self._x_train_fmt = x_train_fmt
        self._splits = splits
        self._main_trace = main_trace
        self._mcmc_state = final_state

        if x_test is not None:
            yhat_test = self.predict(x_test)
            self.yhat_test = yhat_test
            self.yhat_test_mean = yhat_test.mean(axis=0)

    @functools.cached_property
    def yhat_train(self):
        """The conditional posterior mean at `x_train` for each saved iteration."""
        # the MCMC state holds the already-binned training predictors
        x_train = self._mcmc_state['X']
        yhat_train = self._predict(self._main_trace, x_train)
        return self._transform_output(yhat_train, self.offset, self.scale)

    @functools.cached_property
    def yhat_train_mean(self):
        """The marginal posterior mean at `x_train`."""
        return self.yhat_train.mean(axis=0)

    def predict(self, x_test):
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test : array (p, m) or DataFrame
            The test predictors, in the same format as the training
            predictors.

        Returns
        -------
        yhat_test : array (ndpost, m)
            The conditional posterior mean at `x_test` for each MCMC iteration.
        """
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
        x_test = self._bin_predictors(x_test, self._splits)
        yhat_test = self._predict(self._main_trace, x_test)
        return self._transform_output(yhat_test, self.offset, self.scale)

    @staticmethod
    def _process_predictor_input(x):
        """
        Convert predictors to a 2-d jax array with one predictor per row.

        Returns the array and a dict describing the input format, which is
        compared between training and test predictors.
        """
        is_dataframe = hasattr(x, 'columns')
        if is_dataframe:
            # Store the labels as a plain tuple: pandas ``Index == Index`` is
            # elementwise, so keeping the Index would make the dict comparison
            # in _check_compatible_formats raise instead of returning a bool.
            columns = tuple(x.columns)
            # DataFrames have one observation per row; transpose to (p, n)
            x = x.to_numpy().T
        # convert before reading the shape, so nested sequences work too
        x = jnp.asarray(x)
        assert x.ndim == 2
        if is_dataframe:
            fmt = dict(kind='dataframe', columns=columns)
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        return x, fmt

    @staticmethod
    def _check_compatible_formats(fmt1, fmt2):
        """Assert that two predictor formats are identical."""
        assert fmt1 == fmt2

    @staticmethod
    def _process_response_input(y):
        """Convert the responses to a 1-d jax array, returning it and its format."""
        if hasattr(y, 'to_numpy'):
            fmt = dict(kind='series', name=y.name)
            y = y.to_numpy()
        else:
            fmt = dict(kind='array')
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y, fmt

    @staticmethod
    def _check_same_length(x1, x2):
        """Assert that two arrays have the same size along their last axis."""
        assert x1.shape[-1] == x2.shape[-1]

    @staticmethod
    def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
        """
        Determine the scale `lamda` of the prior on the noise variance.

        Returns
        -------
        lamda : float
            The given `lamda` if not None, otherwise a value computed such that
            the `sigquant` quantile of the prior matches ``sigest ** 2``.
        sigest : float or None
            The error standard deviation estimate used, or None if `lamda` was
            given.
        """
        if lamda is not None:
            return lamda, None

        if sigest is not None:
            sigest2 = sigest * sigest
        elif y_train.size < 2:
            sigest2 = 1
        elif y_train.size <= x_train.shape[0]:
            # n <= p: too few observations for the linear regression
            sigest2 = jnp.var(y_train - offset)
        else:
            x_centered = x_train.T - x_train.mean(axis=1)
            y_centered = y_train - y_train.mean()
            # centering is equivalent to adding an intercept column
            _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
            # drop the length-1 axis of the residual sum of squares
            chisq = chisq.squeeze(0)
            # NOTE(review): the implicit intercept also uses one degree of
            # freedom, so an unbiased estimate (like R's summary(lm)$sigma)
            # would divide by n - rank - 1. Left unchanged to preserve current
            # results; confirm intent.
            dof = len(y_train) - rank
            sigest2 = chisq / dof

        # If X ~ chi2(sigdf), then 1/X ~ InvGamma(sigdf/2, scale=1/2); so,
        # assuming jaxext's invgamma.ppf follows scipy (shape a, unit scale),
        # invchi2 is the sigquant quantile of 1/X. The prior on the variance is
        # lamda * sigdf / X, whose sigquant quantile equals sigest2 when
        # lamda = sigest2 / (invchi2 * sigdf).
        alpha = sigdf / 2
        invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
        invchi2rid = invchi2 * sigdf
        return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _process_offset_settings(y_train, offset):
        """Return `offset` if given, else the mean of `y_train` (0 if empty)."""
        if offset is not None:
            return offset
        elif y_train.size < 1:
            return 0
        else:
            return y_train.mean()

    @staticmethod
    def _process_scale_settings(y_train, k):
        """Return half the range of `y_train` divided by `k` (1 if n < 2)."""
        if y_train.size < 2:
            return 1
        else:
            return (y_train.max() - y_train.min()) / (2 * k)

    @staticmethod
    def _determine_splits(x_train, usequants, numcut):
        """Compute the cutpoints of each predictor, using ``numcut + 1`` bins."""
        if usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits):
        """Replace predictor values with their bin indices."""
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _transform_input(y, offset, scale):
        """Standardize the response: ``(y - offset) / scale``."""
        return (y - offset) / scale

    @staticmethod
    def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree, initkw):
        """Build the initial MCMC state; entries of `initkw` override defaults."""
        # one probability per depth 0 .. maxdepth - 2; the deepest level gets
        # no entry, presumably because nodes there can not be split
        depth = jnp.arange(maxdepth - 1)
        p_nonterminal = base / (1 + depth).astype(float) ** power
        # scaled inverse chi-squared (sigdf, lamda) written as an inverse gamma
        sigma2_alpha = sigdf / 2
        sigma2_beta = lamda * sigma2_alpha
        kw = dict(
            X=x_train,
            y=y_train,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_leaf=5,
        )
        kw.update(initkw)
        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
        """Run the MCMC, returning (final_state, burnin_trace, main_trace)."""
        # accept either a ready-made jax PRNG key or an integer seed
        if isinstance(seed, jax.Array) and jnp.issubdtype(seed.dtype, jax.dtypes.prng_key):
            key = seed
        else:
            key = jax.random.key(seed)
        callback = mcmcloop.make_simple_print_callback(printevery)
        return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback, key)

    @staticmethod
    def _predict(trace, x):
        """Evaluate the tree ensemble of each trace sample at binned predictors `x`."""
        return mcmcloop.evaluate_trace(trace, x)

    @staticmethod
    def _transform_output(y, offset, scale):
        """Map standardized values back to the original units of the response."""
        return offset + scale * y

    @staticmethod
    def _extract_sigma(trace, scale):
        """Return the error standard deviation, in original units, from a trace."""
        return scale * jnp.sqrt(trace['sigma2'])

    def _show_tree(self, i_sample, i_tree, print_all=False):
        """Debug helper: print tree `i_tree` of saved sample `i_sample`."""
        from . import debug
        trace = self._main_trace
        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
        var_tree = trace['var_trees'][i_sample, i_tree]
        split_tree = trace['split_trees'][i_sample, i_tree]
        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)

    def _sigma_harmonic_mean(self, prior=False):
        """
        Debug helper: harmonic mean of the error variance, as a std in y units.

        Uses the inverse gamma prior if `prior`, otherwise the conditional
        posterior given the residuals of the final MCMC state. The harmonic
        mean of InvGamma(alpha, beta) is beta / alpha.
        """
        bart = self._mcmc_state
        if prior:
            alpha = bart['sigma2_alpha']
            beta = bart['sigma2_beta']
        else:
            resid = bart['resid']
            alpha = bart['sigma2_alpha'] + resid.size / 2
            norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
            beta = bart['sigma2_beta'] + norm2 / 2
        sigma2 = beta / alpha
        return jnp.sqrt(sigma2) * self.scale

    def _compare_resid(self):
        """Debug helper: stored residuals vs. residuals recomputed from the forest."""
        bart = self._mcmc_state
        resid1 = bart['resid']
        yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
        resid2 = bart['y'] - yhat
        return resid1, resid2

    def _avg_acc(self):
        """Debug helper: overall (grow, prune) acceptance rates after burn-in."""
        trace = self._main_trace

        def acceptance_rate(prefix):
            accepted = trace[f'{prefix}_acc_count']
            proposed = trace[f'{prefix}_prop_count']
            return accepted.sum() / proposed.sum()

        return acceptance_rate('grow'), acceptance_rate('prune')

    def _avg_prop(self):
        """Debug helper: fraction of proposals that were (grow, prune) moves."""
        trace = self._main_trace

        def proposals(prefix):
            return trace[f'{prefix}_prop_count'].sum()

        pgrow = proposals('grow')
        pprune = proposals('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def _avg_move(self):
        """Debug helper: acceptance rate times proposal frequency per move type."""
        agrow, aprune = self._avg_acc()
        pgrow, pprune = self._avg_prop()
        return agrow * pgrow, aprune * pprune

    def _depth_distr(self):
        """Debug helper: tree depth distribution over the saved trace."""
        from . import debug
        trace = self._main_trace
        split_trees = trace['split_trees']
        return debug.trace_depth_distr(split_trees)

    def _points_per_leaf_distr(self):
        """Debug helper: distribution of training points per leaf over the trace."""
        from . import debug
        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state['X'])

    def _check_trees(self):
        """Debug helper: run the debug consistency checks on every saved tree."""
        from . import debug
        return debug.check_trace(self._main_trace, self._mcmc_state)

    def _tree_goes_bad(self):
        """
        Debug helper: flag where a tree first fails the checks.

        Marks entries that fail `_check_trees` while the same tree passed at
        the previous iteration (the first iteration is compared to "passing").
        Assumes the check result has iterations along the first axis and trees
        along the second -- TODO confirm against debug.check_trace.
        """
        bad = self._check_trees().astype(bool)
        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
        return bad & ~bad_before