Coverage for src/bartz/BART.py: 76%
205 statements
coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
# bartz/src/bartz/BART.py
#
# Copyright (c) 2024, Giacomo Petrillo
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import functools

import jax
import jax.numpy as jnp

from . import jaxext
from . import grove
from . import mcmcstep
from . import mcmcloop
from . import prepcovars

class gbart:
    """
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train : array (p, n) or DataFrame
        The training predictors.
    y_train : array (n,) or Series
        The training responses.
    x_test : array (p, m) or DataFrame, optional
        The test predictors.
    usequants : bool, default False
        Whether to use predictor quantiles instead of a uniform grid to bin
        the predictors.
    sigest : float, optional
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression.
        If `y_train` has fewer than two elements, it is set to 1. If n <= p,
        it is set to the variance of `y_train`. Ignored if `lamda` is
        specified.
    sigdf : int, default 3
        The degrees of freedom of the scaled inverse chi-squared prior on the
        noise variance.
    sigquant : float, default 0.9
        The quantile of the prior on the noise variance that is matched to
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k : float, default 2
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power : float, default 2
    base : float, default 0.95
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``, so with the defaults a root node is non-terminal with
        probability 0.95 and a depth-1 node with probability 0.2375.
    maxdepth : int, default 6
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    lamda : float, optional
        The scale of the prior on the noise variance. If ``lamda==1``, the
        prior is an inverse chi-squared scaled to have harmonic mean 1. If
        not specified, it is set based on `sigest` and `sigquant`.
    offset : float, optional
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train`. If `y_train` is empty, it is set to 0.
    ntree : int, default 200
        The number of trees used to represent the latent mean function.
88 If `usequants` is `False`: the exact number of cutpoints used to bin the
89 predictors, ranging between the minimum and maximum observed values
90 (excluded).
92 If `usequants` is `True`: the maximum number of cutpoints to use for
93 binning the predictors. Each predictor is binned such that its
94 distribution in `x_train` is approximately uniform across bins. The
95 number of bins is at most the number of unique values appearing in
96 `x_train`, or ``numcut + 1``.
98 Before running the algorithm, the predictors are compressed to the
99 smallest integer type that fits the bin indices, so `numcut` is best set
100 to the maximum value of an unsigned integer type.
    ndpost : int, default 1000
        The number of MCMC samples to save, after burn-in.
    nskip : int, default 100
        The number of initial MCMC samples to discard as burn-in.
    keepevery : int, default 1
        The thinning factor for the MCMC samples, after burn-in.
    printevery : int, default 100
        The number of iterations (including skipped ones) between each log.
    seed : int or jax random key, default 0
        The seed for the random number generator.

    Attributes
    ----------
    yhat_train : array (ndpost, n)
        The conditional posterior mean at `x_train` for each MCMC iteration.
    yhat_train_mean : array (n,)
        The marginal posterior mean at `x_train`.
    yhat_test : array (ndpost, m)
        The conditional posterior mean at `x_test` for each MCMC iteration.
    yhat_test_mean : array (m,)
        The marginal posterior mean at `x_test`.
    sigma : array (ndpost,)
        The standard deviation of the error.
    first_sigma : array (nskip,)
        The standard deviation of the error in the burn-in phase.
    offset : float
        The prior mean of the latent mean function.
    scale : float
        The prior standard deviation of the latent mean function.
    lamda : float
        The prior harmonic mean of the error variance.
    sigest : float or None
        The estimated standard deviation of the error used to set `lamda`.
    ntree : int
        The number of trees.
    maxdepth : int
        The maximum depth of the trees.
    initkw : dict
        Additional arguments passed to `mcmcstep.init`.

    Methods
    -------
    predict

    Notes
    -----
    This interface imitates the function ``gbart`` from the R package `BART
    <https://cran.r-project.org/package=BART>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - `rm_const` is always `False`.
    - The default `numcut` is 255 instead of 100.
    - A lot of functionality is missing (variable selection, discrete
      response).
    - There are some additional attributes, and some are missing.

    The linear regression used to set `sigest` adds an intercept.
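
    Examples
    --------
    A minimal usage sketch on synthetic data; the data, settings, and import
    path below are illustrative, not prescriptive:

    >>> import jax.numpy as jnp
    >>> from bartz.BART import gbart
    >>> p, n = 5, 100
    >>> x_train = jnp.linspace(0., 1., p * n).reshape(p, n)  # one predictor per row
    >>> y_train = jnp.cos(7 * x_train).sum(axis=0)           # shape (n,)
    >>> model = gbart(x_train, y_train, ndpost=200, nskip=50, seed=0)
    >>> post_mean = model.yhat_train_mean    # marginal posterior mean, shape (n,)
    >>> sigma_draws = model.sigma            # error standard deviation, shape (ndpost,)
    >>> yhat = model.predict(x_train)        # per-iteration means, shape (ndpost, n)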
    """

    def __init__(self, x_train, y_train, *,
        x_test=None,
        usequants=False,
        sigest=None,
        sigdf=3,
        sigquant=0.9,
        k=2,
        power=2,
        base=0.95,
        maxdepth=6,
        lamda=None,
        offset=None,
        ntree=200,
        numcut=255,
        ndpost=1000,
        nskip=100,
        keepevery=1,
        printevery=100,
        seed=0,
        initkw={},
        ):

        x_train, x_train_fmt = self._process_predictor_input(x_train)

        y_train, y_train_fmt = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)

        offset = self._process_offset_settings(y_train, offset)
        scale = self._process_scale_settings(y_train, k)
        lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)

        splits, max_split = self._determine_splits(x_train, usequants, numcut)
        x_train = self._bin_predictors(x_train, splits)

        y_train = self._transform_input(y_train, offset, scale)
        lamda_scaled = lamda / (scale * scale)

        mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree, initkw)
        final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)

        sigma = self._extract_sigma(main_trace, scale)
        first_sigma = self._extract_sigma(burnin_trace, scale)

        self.offset = offset
        self.scale = scale
        self.lamda = lamda
        self.sigest = sigest
        self.ntree = ntree
        self.maxdepth = maxdepth
        self.sigma = sigma
        self.first_sigma = first_sigma

        self._x_train_fmt = x_train_fmt
        self._splits = splits
        self._main_trace = main_trace
        self._mcmc_state = final_state

        if x_test is not None:
            yhat_test = self.predict(x_test)
            self.yhat_test = yhat_test
            self.yhat_test_mean = yhat_test.mean(axis=0)

    @functools.cached_property
    def yhat_train(self):
        x_train = self._mcmc_state['X']
        yhat_train = self._predict(self._main_trace, x_train)
        return self._transform_output(yhat_train, self.offset, self.scale)

    @functools.cached_property
    def yhat_train_mean(self):
        return self.yhat_train.mean(axis=0)

    def predict(self, x_test):
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test : array (p, m) or DataFrame
            The test predictors.

        Returns
        -------
        yhat_test : array (ndpost, m)
            The conditional posterior mean at `x_test` for each MCMC
            iteration.
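
        Examples
        --------
        A sketch, assuming `model` is a fitted `gbart` instance and `x_new` is
        a hypothetical matrix with the same predictors (rows) as `x_train`:

        >>> yhat = model.predict(x_new)    # shape (ndpost, m)
        >>> yhat_mean = yhat.mean(axis=0)  # pointwise posterior mean, shape (m,)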
        """
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
        x_test = self._bin_predictors(x_test, self._splits)
        yhat_test = self._predict(self._main_trace, x_test)
        return self._transform_output(yhat_test, self.offset, self.scale)

    @staticmethod
    def _process_predictor_input(x):
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _check_compatible_formats(fmt1, fmt2):
        assert fmt1 == fmt2

    @staticmethod
    def _process_response_input(y):
        if hasattr(y, 'to_numpy'):
            fmt = dict(kind='series', name=y.name)
            y = y.to_numpy()
        else:
            fmt = dict(kind='array')
        y = jnp.asarray(y)
        assert y.ndim == 1
        return y, fmt

    @staticmethod
    def _check_same_length(x1, x2):
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @staticmethod
    def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
        if lamda is not None:
            return lamda, None
        else:
            if sigest is not None:
                sigest2 = sigest * sigest
            elif y_train.size < 2:
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                sigest2 = jnp.var(y_train - offset)
            else:
                x_centered = x_train.T - x_train.mean(axis=1)
                y_centered = y_train - y_train.mean()
                # centering is equivalent to adding an intercept column
                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                chisq = chisq.squeeze(0)
                dof = len(y_train) - rank
                sigest2 = chisq / dof
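            # Calibrate `lamda` so that the `sigquant` quantile of the scaled
            # inverse chi-squared prior on the noise variance equals `sigest2`
            # (see the `sigquant` parameter in the class docstring).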
            alpha = sigdf / 2
            invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)

    @staticmethod
    def _process_offset_settings(y_train, offset):
        if offset is not None:
            return offset
        elif y_train.size < 1:
            return 0
        else:
            return y_train.mean()

    @staticmethod
    def _process_scale_settings(y_train, k):
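        # The prior scale is half the observed range of `y_train` divided by
        # `k` (see the `k` parameter); it falls back to 1 when the range is
        # not defined.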
        if y_train.size < 2:
            return 1
        else:
            return (y_train.max() - y_train.min()) / (2 * k)

    @staticmethod
    def _determine_splits(x_train, usequants, numcut):
        if usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(x, splits):
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _transform_input(y, offset, scale):
        return (y - offset) / scale

    @staticmethod
    def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree, initkw):
        depth = jnp.arange(maxdepth - 1)
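        # Per-depth prior probability that a node is non-terminal,
        # ``base / (1 + d) ** power`` for depths 0 .. maxdepth - 2; nodes at
        # the deepest level are always terminal (see the class docstring).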
        p_nonterminal = base / (1 + depth).astype(float) ** power
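        # The scaled inverse chi-squared prior on the noise variance,
        # InvChi2(sigdf, lamda), is passed on in its inverse-gamma
        # parameterization: alpha = sigdf / 2, beta = lamda * sigdf / 2,
        # which has harmonic mean beta / alpha = lamda.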
        sigma2_alpha = sigdf / 2
        sigma2_beta = lamda * sigma2_alpha
        kw = dict(
            X=x_train,
            y=y_train,
            max_split=max_split,
            num_trees=ntree,
            p_nonterminal=p_nonterminal,
            sigma2_alpha=sigma2_alpha,
            sigma2_beta=sigma2_beta,
            min_points_per_leaf=5,
        )
        kw.update(initkw)
        return mcmcstep.init(**kw)

    @staticmethod
    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
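        # `seed` may be either an integer or a pre-built JAX PRNG key (see the
        # `seed` parameter in the class docstring).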
        if isinstance(seed, jax.Array) and jnp.issubdtype(seed.dtype, jax.dtypes.prng_key):
            key = seed
        else:
            key = jax.random.key(seed)
        callback = mcmcloop.make_simple_print_callback(printevery)
        return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback, key)

    @staticmethod
    def _predict(trace, x):
        return mcmcloop.evaluate_trace(trace, x)

    @staticmethod
    def _transform_output(y, offset, scale):
        return offset + scale * y

    @staticmethod
    def _extract_sigma(trace, scale):
        return scale * jnp.sqrt(trace['sigma2'])

    def _show_tree(self, i_sample, i_tree, print_all=False):
        from . import debug
        trace = self._main_trace
        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
        var_tree = trace['var_trees'][i_sample, i_tree]
        split_tree = trace['split_trees'][i_sample, i_tree]
        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)

    def _sigma_harmonic_mean(self, prior=False):
        bart = self._mcmc_state
        if prior:
            alpha = bart['sigma2_alpha']
            beta = bart['sigma2_beta']
        else:
            resid = bart['resid']
            alpha = bart['sigma2_alpha'] + resid.size / 2
            norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
            beta = bart['sigma2_beta'] + norm2 / 2
        sigma2 = beta / alpha
        return jnp.sqrt(sigma2) * self.scale

    def _compare_resid(self):
        bart = self._mcmc_state
        resid1 = bart['resid']
        yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
        resid2 = bart['y'] - yhat
        return resid1, resid2

    def _avg_acc(self):
        trace = self._main_trace
        def acc(prefix):
            acc = trace[f'{prefix}_acc_count']
            prop = trace[f'{prefix}_prop_count']
            return acc.sum() / prop.sum()
        return acc('grow'), acc('prune')

    def _avg_prop(self):
        trace = self._main_trace
        def prop(prefix):
            return trace[f'{prefix}_prop_count'].sum()
        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def _avg_move(self):
        agrow, aprune = self._avg_acc()
        pgrow, pprune = self._avg_prop()
        return agrow * pgrow, aprune * pprune

    def _depth_distr(self):
        from . import debug
        trace = self._main_trace
        split_trees = trace['split_trees']
        return debug.trace_depth_distr(split_trees)

    def _points_per_leaf_distr(self):
        from . import debug
        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state['X'])

    def _check_trees(self):
        from . import debug
        return debug.check_trace(self._main_trace, self._mcmc_state)

    def _tree_goes_bad(self):
        bad = self._check_trees().astype(bool)
        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
        return bad & ~bad_before