Coverage for src/lsqfitgp/bayestree/_bcf.py: 80%
359 statements
coverage.py v7.9.2, created at 2025-07-17 13:39 +0000
# lsqfitgp/bayestree/_bcf.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools
import warnings

import numpy
from scipy import stats
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter
from .. import _jaxext
from .. import _gvarext
from .. import _utils

# TODO add methods or options to do causal inference stuff, e.g., impute missing
# outcomes, or ate, att, cate, catt, sate, satt. Remember that the effect may
# also depend on aux. See bartCause, possibly copy its naming.

def _recursive_cast(dtype, default, mapping):
    if dtype in mapping:
        return mapping[dtype]
    elif dtype.names is not None:
        return numpy.dtype([
            (name, _recursive_cast(dtype[name], default, mapping))
            for name in dtype.names
        ])
    elif dtype.subdtype is not None:
        # note: has names => does not have subdtype
        return numpy.dtype((_recursive_cast(dtype.base, default, mapping), dtype.shape))
    elif default is None:
        return dtype
    else:
        return default

def cast(dtype, default, mapping={}):
    """
    Recursively cast a numpy data type.

    Parameters
    ----------
    dtype : dtype
        The data type to cast.
    default : dtype or None
        The leaf fields of `dtype` are cast to `default`, which can be
        structured, unless they appear in `mapping`. If None, dtypes not in
        `mapping` are left unchanged.
    mapping : dict
        A dictionary from dtypes to dtypes, indicating specific casting rules.
        The dtypes can be structured, a match of a structured dtype takes
        precedence over matches in its leaves, and the converted dtype is not
        further searched for matches.

    Returns
    -------
    casted_dtype : dtype
        The cast version of `dtype`. May not have the same structure if
        `mapping` contains structured dtypes.
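
    Examples
    --------
    A minimal sketch of the intended usage (hypothetical dtypes, not taken
    from the test suite):

    >>> import numpy
    >>> dt = numpy.dtype([('a', 'f8'), ('b', [('c', 'i8')])])
    >>> cast(dt, 'f4', mapping={'i8': 'i4'})
    dtype([('a', '<f4'), ('b', [('c', '<i4')])])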
82 """
83 mapping = {numpy.dtype(k): numpy.dtype(v) for k, v in mapping.items()} 1abcde
84 default = None if default is None else numpy.dtype(default) 1abcde
85 return _recursive_cast(numpy.dtype(dtype), default, mapping) 1abcde
87 # TODO
88 # - move this to generic utils
89 # - make unit tests

class bcf:

    def __init__(self, *,
        y,
        z,
        x_mu,
        x_tau=None,
        pihat,
        include_pi='mu',
        weights=None,
        fitkw={},
        kernelkw_mu={},
        kernelkw_tau={},
        marginalize_mean=True,
        gpaux=None,
        x_aux=None,
        otherhp={},
        transf='standardize',
    ):
        r"""
        Nonparametric Bayesian regression with a GP version of BCF.

        BCF (Bayesian Causal Forests) is a regression method for observational
        causal inference studies introduced in [1]_ based on a pair of BART
        models.

        This class evaluates a Gaussian process regression with a kernel which
        accurately approximates BCF in the infinite trees limit of each BART
        model. The hyperparameters are optimized to their marginal MAP.

        The model is (loosely, see notes below) :math:`y = \mu(x) + z\tau(x)`,
        so :math:`\tau(x)` is the expected causal effect of :math:`z` on
        :math:`y` at location :math:`x`.

        Parameters
        ----------
        y : (n,) array, series or dataframe
            Outcome.
        z : (n,) array, series or dataframe
            Binary treatment status: 0 control group, 1 treatment group.
        x_mu : (n, p) array, series or dataframe
            Covariates for the :math:`\mu` model.
        x_tau : (n, q) array, series or dataframe, optional
            Covariates for the :math:`\tau` model. If not specified, use `x_mu`.
        pihat : (n,) array, series or dataframe
            Estimated propensity score, i.e., P(Z=1|X).
        include_pi : {'mu', 'tau', 'both'}, optional
            Whether to include the propensity score in the :math:`\mu` model,
            the :math:`\tau` model, or both. Default is ``'mu'``.
        weights : (n,) array, series or dataframe, optional
            Weights used to rescale the error variance (as 1 / weight).
        fitkw : dict
            Additional arguments passed to `~lsqfitgp.empbayes_fit`, overriding
            the defaults.
        kernelkw_mu, kernelkw_tau : dict
            Additional arguments passed to `~lsqfitgp.BART` for each model,
            overriding the defaults.
        marginalize_mean : bool
            If True (default), marginalize the intercept of the model.
        gpaux : callable, optional
            If specified, this function is called with a pair ``(hp, gp)``,
            where ``hp`` is a dictionary of hyperparameters, and ``gp`` is a
            `~lsqfitgp.GP` object under construction, and is expected to return
            a modified ``gp`` with a new process named ``'aux'`` defined with
            `~lsqfitgp.GP.defproc` or similar. The process is added to the
            regression model. The input to the process is a structured array
            with fields:

            'train' : bool
                Indicates whether the data is training set (the one passed on
                initialization) or test set (the one passed to `pred` or `gp`).
            'i' : int
                Index of the flattened array.
            'z' : int
                Treatment status.
            'mu', 'tau' : structured
                The values in `x_mu` and `x_tau`, converted to indices according
                to the BART grids. Where `pihat` has been added, there are two
                subfields: ``'x'`` which contains the covariates, and
                ``'pihat'``, the latter expressed in indices as well.
            'pihat' : float
                The `pihat` argument. Contrary to the subfield included under
                ``'mu'`` and/or ``'tau'``, this field contains the original
                values.
            'aux' : structured
                The values in `x_aux`, if specified.
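
            A minimal sketch of a conforming callable (assumes the package is
            imported as ``lgp`` and that a smooth kernel on the ``'pihat'``
            field is a sensible choice; illustrative only):

            >>> def gpaux(hp, gp):
            ...     # add a smooth extra effect of the propensity score
            ...     return gp.defproc('aux', lgp.ExpQuad(dim='pihat'))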
        x_aux : (n, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process.
        otherhp : dictionary of gvar
            A dictionary with the prior of arbitrary additional
            hyperparameters, intended to be used by ``gpaux`` or ``transf``.
        transf : (list of) str or pair of callable
            Data transformation. Either a string indicating a pre-defined
            transformation, or a pair ``(from_data, to_data)``, two functions
            with signatures ``from_data(hp, y) -> eta`` and ``to_data(hp, eta)
            -> y``, where ``eta`` is the value to which the model is fit, and
            ``hp`` is the dictionary of hyperparameters. The functions must be
            ufuncs and one the inverse of the other w.r.t. the second
            parameter. ``from_data`` must be differentiable with `jax` w.r.t.
            ``y``.

            If a list of such specifications is provided, the transformations
            are applied in order, with the first one being the outermost, i.e.,
            the one applied first to the data.

            If a transformation uses additional hyperparameters, either
            predefined automatically or passed by the user through `otherhp`,
            they are inferred with the rest of the hyperparameters.

            The pre-defined transformations are:

            'standardize' (default)
                eta = (y - mean(train_y)) / sdev(train_y)
            'yeojohnson'
                The Yeo-Johnson transformation [2]_ to reduce skewness. The
                :math:`\lambda` parameter is bounded in :math:`(0, 2)` for
                implementation convenience; this restriction may be lifted in
                future versions.
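
            For example, a log transformation for positive outcomes could be
            specified as a custom pair (a sketch, assuming ``jnp`` is
            ``jax.numpy`` and the outcomes are strictly positive):

            >>> transf = (lambda hp, y: jnp.log(y),
            ...           lambda hp, eta: jnp.exp(eta))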

        Notes
        -----
        The regression model is:

        .. math::
            \eta_i = g(y_i; \ldots) &= m + {} \\
                &\phantom{{}={}} +
                \lambda_\mu
                \mu(\mathbf x^\mu_i, \hat\pi_i?) + {} \\
                &\phantom{{}={}} +
                \lambda_\tau
                \tau(\mathbf x^\tau_i, \hat\pi_i?) (z_i - z_0) + {} \\
                &\phantom{{}={}} +
                \mathrm{aux}(i, z_i, \mathbf x^\mu_i, \mathbf x^\tau_i,
                    \hat\pi_i, \mathbf x^\text{aux}_i) + {} \\
                &\phantom{{}={}} +
                \varepsilon_i, \\
            \varepsilon_i &\sim
                N(0, \sigma^2 / w_i), \\
            m &\sim N(0, 1), \\
            \log \sigma^2 &\sim N(\log\bar w, 4), \\
            \lambda_\mu
                &\sim \mathrm{HalfCauchy}(2), \\
            \lambda_\tau
                &\sim \mathrm{HalfNormal}(1.48), \\
            \mu &\sim \mathrm{GP}(0,
                \mathrm{BART}(\alpha_\mu, \beta_\mu) ), \\
            \tau &\sim \mathrm{GP}(0,
                \mathrm{BART}(\alpha_\tau, \beta_\tau) ), \\
            \mathrm{aux} & \sim \mathrm{GP}(0, \text{<user defined>}), \\
            \alpha_\mu, \alpha_\tau &\sim \mathrm{Beta}(2, 1), \\
            \beta_\mu, \beta_\tau &\sim \mathrm{InvGamma}(1, 1), \\
            z_0 &\sim U(0, 1).

        To perform the inference, :math:`(\mu, \tau, \boldsymbol\varepsilon, m,
        \mathrm{aux})` are marginalized analytically, and the marginal
        posterior mode of
        :math:`(\sigma, \lambda_*, \alpha_*, \beta_*, z_0, \ldots)` is found by
        numerical minimization, after transforming them to express their prior
        as a Gaussian copula. Their marginal posterior covariance matrix is
        estimated with an approximation of the hessian inverse. See
        `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize
        this procedure.

        The tree splitting grid of the BART kernel is set using quantiles of the
        observed covariates. This corresponds to settings ``usequants=True``,
        ``numcut=inf`` in the R packages BayesTree and BART. Use the parameters
        `kernelkw_mu` and `kernelkw_tau` to customize the grids.

        The difference between the regression model evaluated at :math:`Z=1` vs.
        :math:`Z=0` can be interpreted as the causal effect :math:`Z \rightarrow
        Y` if the unconfoundedness assumption is made:

        .. math::
            \{Y(Z=0), Y(Z=1)\} \perp\!\!\!\perp Z \mid X.

        In practical terms, this holds when:

        1) :math:`X` are pre-treatment variables, i.e., they represent
           quantities causally upstream of :math:`Z`.

        2) :math:`X` are sufficient to adjust for all common causes of
           :math:`Z` and :math:`Y`, such that the only remaining difference
           is the causal effect and not just a correlation.

        Here :math:`X` consists of `x_tau`, `x_mu` and `x_aux`. However these
        arrays may also be used to pass "technical" values used to set up the
        model, that do not satisfy the unconfoundedness assumption, if you know
        what you are doing.

        Attributes
        ----------
        m : float or gvar
            The prior mean :math:`m`.
        sigma : gvar
            The error term standard deviation :math:`\sigma`. If there are
            weights, the sdev for each unit is obtained by dividing ``sigma``
            by sqrt(weight).
        alpha_mu, alpha_tau : gvar
            The numerator of the tree spawn probability :math:`\alpha_*` (named
            ``base`` in R bcf).
        beta_mu, beta_tau : gvar
            The depth exponent of the tree spawn probability :math:`\beta_*`
            (named ``power`` in R bcf).
        lambda_mu, lambda_tau : gvar
            The prior standard deviation :math:`\lambda_*`.
        z_0 : gvar
            The treatment coding parameter.
        fit : empbayes_fit
            The hyperparameters fit object.

        Methods
        -------
        gp :
            Create a GP object.
        data :
            Create the dictionary to be passed to `GP.pred` to represent data.
        pred :
            Evaluate the regression function at given locations.
        from_data :
            Convert :math:`y` to :math:`\eta`.
        to_data :
            Convert :math:`\eta` to :math:`y`.

        See also
        --------
        lsqfitgp.BART

        References
        ----------
        .. [1] P. Richard Hahn, Jared S. Murray, Carlos M. Carvalho, "Bayesian
            Regression Tree Models for Causal Inference: Regularization,
            Confounding, and Heterogeneous Effects (with Discussion)," Bayesian
            Analysis 15(3), 965-1056, September 2020,
            https://doi.org/10.1214/19-BA1195
        .. [2] Yeo, In-Kwon; Johnson, Richard A. (2000). "A New Family of Power
            Transformations to Improve Normality or Symmetry". Biometrika. 87
            (4): 954–959. https://doi.org/10.1093/biomet/87.4.954
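
        Examples
        --------
        A minimal sketch on synthetic data (hypothetical values; all other
        settings left at their defaults):

        >>> import numpy as np
        >>> rng = np.random.default_rng(0)
        >>> n = 100
        >>> x = rng.standard_normal((n, 2))
        >>> z = rng.binomial(1, 0.5, n)
        >>> pihat = np.full(n, 0.5)  # placeholder propensity estimate
        >>> y = x @ [1., -1.] + 0.5 * z + rng.standard_normal(n)
        >>> model = bcf(y=y, z=z, x_mu=x, pihat=pihat)
        >>> model.sigma  # posterior summary of the error sdev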
328 """
330 # convert covariates to StructuredArray
331 x_mu = self._to_structured(x_mu) 1abcde
332 if x_tau is not None: 1abcde
333 x_tau = self._to_structured(x_tau) 1bcde
334 assert x_tau.shape == x_mu.shape 1bcde
335 if x_aux is not None: 1abcde
336 x_aux = self._to_structured(x_aux) 1bcde
337 assert x_aux.shape == x_mu.shape 1bcde
339 # convert outcomes, treatment, propensity score, weights to 1d arrays
340 y = self._to_vector(y) 1abcde
341 z = self._to_vector(z) 1abcde
342 pihat = self._to_vector(pihat) 1abcde
343 assert y.shape == z.shape == pihat.shape == x_mu.shape 1abcde
344 if weights is not None: 344 ↛ 345line 344 didn't jump to line 345 because the condition on line 344 was never true1abcde
345 weights = self._to_vector(weights)
346 assert weights.shape == x_mu.shape
348 # check include_pi
349 if include_pi not in ('mu', 'tau', 'both'): 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true1abcde
350 raise KeyError(f'invalid value include_pi={include_pi!r}')
351 self._include_pi = include_pi 1abcde
353 # add pihat to covariates
354 x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat) 1abcde
356 # grid and indices
357 splits_mu = _kernels.BART.splits_from_coord(x_mu) 1abcde
358 i_mu = self._toindices(x_mu, splits_mu) 1abcde
359 if x_tau is None: 1abcde
360 splits_tau = splits_mu 1abcde
361 i_tau = None 1abcde
362 else:
363 splits_tau = _kernels.BART.splits_from_coord(x_tau) 1bcde
364 i_tau = self._toindices(x_tau, splits_tau) 1bcde
366 # get functions for data transformation
367 from_data, to_data, transfloss, transfhp = self._get_transf( 1abcde
368 transf=transf, weights=weights, y=y)
370 # scale of error variance
371 logsigma2_loc = 0 if weights is None else numpy.log(jnp.mean(weights)) 1abcde
373 # prior on hyperparams
374 hyperprior = copula.makedict({ 1abcde
375 'm': gvar.gvar(0, 1),
376 'sigma^2': copula.lognorm(logsigma2_loc, 2),
377 'lambda_mu': copula.halfcauchy(2),
378 'lambda_tau': copula.halfnorm(1.48),
379 'alpha_mu': copula.beta(2, 1),
380 'alpha_tau': copula.beta(2, 1),
381 'beta_mu': copula.invgamma(1, 1),
382 'beta_tau': copula.invgamma(1, 1),
383 'z_0': copula.uniform(0, 1),
384 })
386 # remove explicit mean parameter if it's baked into the Gaussian process
387 if marginalize_mean: 1abcde
388 hyperprior.pop('m') 1abcde
390 # add data transformation and user hyperparameters
391 def update_hyperparams(new, newname, raises): 1abcde
392 new = gvar.BufferDict(new) 1abcde
393 for key in new.all_keys(): 1abcde
394 if hyperprior.has_dictkey(key): 394 ↛ 395line 394 didn't jump to line 395 because the condition on line 394 was never true1abcde
395 message = f'{newname} hyperparameter {key!r} overrides existing one'
396 if raises:
397 raise ValueError(message)
398 else:
399 warnings.warn(message)
400 hyperprior.update(new) 1abcde
401 update_hyperparams(transfhp, 'data transformation', True) 1abcde
402 # the hypers handed by _get_transf are not allowed to override
403 update_hyperparams(otherhp, 'user', False) 1abcde
405 # GP factory
406 def gpfactory(hp, *, z, i_mu, i_tau, pihat, x_aux, weights, 1abcde
407 splits_mu, splits_tau, **_):
409 # TODO maybe I should pass kernelkw_* as arguments, but they may not
410 # be jittable. I need jitkw in empbayes_fit for that.
412 kw_overridable = dict( 1abcde
413 maxd=10,
414 reset=[2, 4, 6, 8],
415 intercept=False,
416 )
417 kw_not_overridable = dict(indices=True) 1abcde
419 gp = _GP.GP(checkpos=False, checksym=False, solver='chol') 1abcde
421 for name, kernelkw in dict(mu=kernelkw_mu, tau=kernelkw_tau).items(): 1abcde
422 kw = dict( 1abcde
423 alpha=hp[f'alpha_{name}'],
424 beta=hp[f'beta_{name}'],
425 dim=name,
426 splits=eval(f'splits_{name}'),
427 **kw_overridable,
428 )
429 kw.update(kernelkw) 1abcde
430 kernel = _kernels.BART(**kw, **kw_not_overridable) 1abcde
431 kernel *= hp[f'lambda_{name}'] ** 2 1abcde
433 gp = gp.defproc(name, kernel) 1abcde
435 if 'm' in hp: 1abcde
436 kernel_mean = 0 * _kernels.Constant() 1bcde
437 else:
438 kernel_mean = _kernels.Constant() 1abcde
439 gp = gp.defproc('m', kernel_mean) 1abcde
441 if gpaux is None: 1abcde
442 gp = gp.defproc('aux', 0 * _kernels.Constant()) 1abcde
443 else:
444 gp = gpaux(hp, gp) 1bcde
446 gp = gp.deflintransf( 1abcde
447 gp.DefaultProcess,
448 lambda m, mu, tau, aux: lambda x:
449 m(x) + mu(x) + tau(x) * (x['z'] - hp['z_0']) + aux(x),
450 ['m', 'mu', 'tau', 'aux'],
451 )
453 x = self._join_points(True, z, i_mu, i_tau, pihat, x_aux) 1abcde
454 gp = gp.addx(x, 'trainmean') 1abcde
455 errcov = self._error_cov(hp, weights, x) 1abcde
456 return (gp 1abcde
457 .addcov(errcov, 'trainnoise')
458 .addtransf({'trainmean': 1, 'trainnoise': 1}, 'train')
459 )
461 # data factory
462 def data(hp, *, y, **_): 1abcde
463 return {'train': from_data(hp, y) - hp.get('m', 0)} 1abcde
465 # fit hyperparameters
466 options = dict( 1abcde
467 verbosity=3,
468 minkw=dict(
469 method='l-bfgs-b',
470 options=dict(
471 maxls=4,
472 maxiter=100,
473 ),
474 ),
475 mlkw=dict(
476 epsrel=0,
477 ),
478 forward=True,
479 gpfactorykw=dict(
480 y=y,
481 z=z,
482 i_mu=i_mu,
483 i_tau=i_tau,
484 pihat=pihat,
485 x_aux=x_aux,
486 weights=weights,
487 splits_mu=splits_mu,
488 splits_tau=splits_tau,
489 ),
490 additional_loss=transfloss,
491 )
492 options.update(fitkw) 1abcde
493 fit = _fit.empbayes_fit(hyperprior, gpfactory, data, **options) 1abcde
495 # extract hyperparameters from minimization result
496 self.m = fit.p.get('m', 0) 1abcde
497 self.sigma = gvar.sqrt(fit.p['sigma^2']) 1abcde
498 self.lambda_mu = fit.p['lambda_mu'] 1abcde
499 self.lambda_tau = fit.p['lambda_tau'] 1abcde
500 self.alpha_mu = fit.p['alpha_mu'] 1abcde
501 self.alpha_tau = fit.p['alpha_tau'] 1abcde
502 self.beta_mu = fit.p['beta_mu'] 1abcde
503 self.beta_tau = fit.p['beta_tau'] 1abcde
504 self.z_0 = fit.p['z_0'] 1abcde
506 # save other attributes
507 self.fit = fit 1abcde
508 self._from_data = from_data 1abcde
509 self._to_data = to_data 1abcde
511 def _append_pihat(self, x_mu, x_tau, pihat): 1fabcde
512 ip = self._include_pi 1abcde
513 if ip == 'mu' or ip == 'both': 1abcde
514 x_mu = _array.StructuredArray.from_dict(dict( 1abcde
515 x=x_mu,
516 pihat=pihat,
517 ))
518 if x_tau is not None and (ip == 'tau' or ip == 'both'): 1abcde
519 x_tau = _array.StructuredArray.from_dict(dict( 1bcde
520 x=x_tau,
521 pihat=pihat,
522 ))
523 return x_mu, x_tau 1abcde
525 @staticmethod 1fabcde
526 def _join_points(train, z, i_mu, i_tau, pihat, x_aux): 1fabcde
527 """ join covariates into a single StructuredArray """
528 columns = dict( 1abcde
529 train=jnp.broadcast_to(bool(train), z.shape),
530 i=jnp.arange(z.size).reshape(z.shape),
531 z=z,
532 mu=i_mu,
533 tau=i_mu if i_tau is None else i_tau,
534 pihat=pihat,
535 )
536 if x_aux is not None: 1abcde
537 columns.update(aux=x_aux) 1bcde
538 return _array.StructuredArray.from_dict(columns) 1abcde
540 @staticmethod 1fabcde
541 def _error_cov(hp, weights, x): 1fabcde
542 """ fill error covariance matrix """
543 if weights is None: 543 ↛ 546line 543 didn't jump to line 546 because the condition on line 543 was always true1abcde
544 error_var = jnp.broadcast_to(hp['sigma^2'], len(x)) 1abcde
545 else:
546 error_var = hp['sigma^2'] / weights
547 return jnp.diag(error_var) 1abcde
549 def _gethp(self, hp, rng): 1fabcde
550 if not isinstance(hp, str): 550 ↛ 551line 550 didn't jump to line 551 because the condition on line 550 was never true1abcde
551 return hp
552 elif hp == 'map': 1abcde
553 return self.fit.pmean 1abcde
554 elif hp == 'sample': 554 ↛ 557line 554 didn't jump to line 557 because the condition on line 554 was always true1abcde
555 return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng) 1abcde
556 else:
557 raise KeyError(hp)
559 def gp(self, *, hp='map', z=None, x_mu=None, x_tau=None, pihat=None, 1fabcde
560 x_aux=None, weights=None, rng=None):
561 """
562 Create a Gaussian process with the fitted hyperparameters.
564 Parameters
565 ----------
566 hp : str or dict
567 The hyperparameters to use. If ``'map'`` (default), use the marginal
568 maximum a posteriori. If ``'sample'``, sample hyperparameters from
569 the posterior. If a dict, use the given hyperparameters.
570 z : (m,) array, series or dataframe, optional
571 Treatment status at test points. If specified, also `x_mu`, `pihat`,
572 `x_tau` and `x_aux` (the latter two if and only also specified at
573 initialization) must be specified.
574 x_mu : (m, p) array, series or dataframe, optional
575 Control model covariates at test points.
576 x_tau : (m, q) array, series or dataframe, optional
577 Moderating model covariates at test points.
578 pihat : (m,) array, series or dataframe, optional
579 Estimated propensity score at test points.
580 x_aux : (m, k) array, series or dataframe, optional
581 Additional covariates for the ``'aux'`` process.
582 weights : (m,) array, series or dataframe, optional
583 Weights for the error variance on the test points.
584 rng : numpy.random.Generator, optional
585 Random number generator, used if ``hp == 'sample'``.
587 Returns
588 -------
589 gp : GP
590 A centered Gaussian process object. To add the mean, use the `m`
591 attribute of the `bcf` object. The keys of the GP are ``'@mean'``,
592 ``'@noise'``, and ``'@'``, where the "@" stands either for 'train'
593 or 'test', and @ = @mean + @noise.
595 This Gaussian process is defined on the transformed data ``eta``.
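
            A sketch of the intended flow (``model`` is a hypothetical fitted
            `bcf` instance; this mirrors what `pred` does internally):

            >>> gp = model.gp()
            >>> eta_post = gp.predfromdata(model.data(), 'train')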
596 """
598 hp = self._gethp(hp, rng)
599 return self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw)
601 def _gp(self, hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw): 1fabcde
602 """
603 Internal function to create the GP object. This function must work
604 both if the arguments are user-provided and need to be checked and
605 converted to standard format, or if they are traced jax values.
606 """
608 # create GP object
609 gp = self.fit.gpfactory(hp, **gpfactorykw) 1abcde
611 # add test points
612 if z is not None: 1abcde
614 # check presence/absence of arguments is coherent
615 self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux) 1a
617 # check treatment and propensity score
618 z = self._to_vector(z) 1a
619 pihat = self._to_vector(pihat) 1a
620 assert pihat.shape == z.shape 1a
622 # check weights
623 if weights is not None: 623 ↛ 624line 623 didn't jump to line 624 because the condition on line 623 was never true1a
624 weights = self._to_vector(weights)
625 assert weights.shape == z.shape
627 # add propensity score to covariates
628 x_mu = self._to_structured(x_mu) 1a
629 assert x_mu.shape == z.shape 1a
630 if x_tau is not None: 630 ↛ 631line 630 didn't jump to line 631 because the condition on line 630 was never true1a
631 x_tau = self._to_structured(x_tau)
632 assert x_tau.shape == z.shape
633 x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat) 1a
635 # convert covariates to indices
636 i_mu = self._toindices(x_mu, gpfactorykw['splits_mu']) 1a
637 assert i_mu.dtype == gpfactorykw['i_mu'].dtype 1a
638 if x_tau is not None: 638 ↛ 639line 638 didn't jump to line 639 because the condition on line 638 was never true1a
639 i_tau = self._toindices(x_tau, gpfactorykw['splits_tau'])
640 assert i_tau.dtype == gpfactorykw['i_tau'].dtype
641 else:
642 i_tau = None 1a
644 # check auxiliary points
645 if x_aux is not None: 645 ↛ 646line 645 didn't jump to line 646 because the condition on line 645 was never true1a
646 x_aux = self._to_structured(x_aux)
648 # add test points
649 x = self._join_points(False, z, i_mu, i_tau, pihat, x_aux) 1a
650 gp = gp.addx(x, 'testmean') 1a
651 errcov = self._error_cov(hp, weights, x) 1a
652 gp = (gp 1a
653 .addcov(errcov, 'testnoise')
654 .addtransf({'testmean': 1, 'testnoise': 1}, 'test')
655 )
657 return gp 1abcde
659 def _check_coherent_covariates(self, z, x_mu, x_tau, pihat, x_aux): 1fabcde
660 if z is None: 1abcde
661 assert x_mu is None 1bcde
662 assert x_tau is None 1bcde
663 assert pihat is None 1bcde
664 assert x_aux is None 1bcde
665 else:
666 assert x_mu is not None 1a
667 assert pihat is not None 1a
668 train_tau = self.fit.gpfactorykw['i_tau'] 1a
669 if x_tau is None: 669 ↛ 672line 669 didn't jump to line 672 because the condition on line 669 was always true1a
670 assert train_tau is None 1a
671 else:
672 assert train_tau is not None
673 train_aux = self.fit.gpfactorykw['x_aux'] 1a
674 if x_aux is None: 674 ↛ 677line 674 didn't jump to line 677 because the condition on line 674 was always true1a
675 assert train_aux is None 1a
676 else:
677 assert train_aux is not None
679 def data(self, *, hp='map', rng=None): 1fabcde
680 """
681 Get the data to be passed to `GP.pred` on a GP object returned by `gp`.
683 Parameters
684 ----------
685 hp : str or dict
686 The hyperparameters to use. If ``'map'`` (default), use the marginal
687 maximum a posteriori. If ``'sample'``, sample hyperparameters from
688 the posterior. If a dict, use the given hyperparameters.
689 rng : numpy.random.Generator, optional
690 Random number generator, used if ``hp == 'sample'``.
692 Returns
693 -------
694 data : dict
695 A dictionary representing ``eta`` in the format required by the
696 `GP.pred` method.
697 """
699 hp = self._gethp(hp, rng)
700 return self.fit.data(hp, **self.fit.gpfactorykw)
702 def pred(self, *, hp='map', error=False, z=None, x_mu=None, x_tau=None, 1fabcde
703 pihat=None, x_aux=None, weights=None, transformed=True, samples=None,
704 gvars=False, rng=None):
705 r"""
706 Predict the transformed outcome at given locations.
708 Parameters
709 ----------
710 hp : str or dict
711 The hyperparameters to use. If ``'map'`` (default), use the marginal
712 maximum a posteriori. If ``'sample'``, sample hyperparameters from
713 the posterior. If a dict, use the given hyperparameters.
714 error : bool, default False
715 If ``False``, make a prediction for the latent mean. If ``True``,
716 add the error term.
717 z : (m,) array, series or dataframe, optional
718 Treatment status at test points. If specified, also `x_mu`, `pihat`,
719 `x_tau` and `x_aux` (the latter two if and only also specified at
720 initialization) must be specified.
721 x_mu : (m, p) array, series or dataframe, optional
722 :math:`\mu` model covariates at test points.
723 x_tau : (m, q) array, series or dataframe, optional
724 :math:`\tau` model covariates at test points.
725 pihat : (m,) array, series or dataframe, optional
726 Estimated propensity score at test points.
727 x_aux : (m, k) array, series or dataframe, optional
728 Additional covariates for the ``'aux'`` process at test points.
729 weights : (m,) array, series or dataframe, optional
730 Weights for the error variance on the test points.
731 transformed : bool, default True
732 If ``True``, return the prediction on the transformed outcome
733 :math:`\eta`, else the observable outcome :math:`y`.
734 samples : int, optional
735 If specified, indicates the number of samples to take from the
736 posterior. If not, return the mean and covariance matrix of the
737 posterior.
738 gvars : bool, default False
739 If ``True``, return the mean and covariance matrix of the posterior
740 as an array of `GVar` variables.
741 rng : numpy.random.Generator, optional
742 Random number generator, used if ``hp == 'sample'`` or ``samples``
743 is not `None`.
745 Returns
746 -------
747 If ``samples`` is `None` and ``gvars`` is `False` (default):
749 mean, cov : (m,) and (m, m) arrays
750 The mean and covariance matrix of the Normal posterior distribution
751 over the regression function or :math:`\eta` at the specified
752 locations.
754 If ``samples`` is `None` and ``gvars`` is `True`:
756 out : (m,) array of gvars
757 The same distribution represented as an array of `~gvar.GVar`
758 objects.
760 If ``samples`` is an integer:
762 sample : (samples, m) array
763 Posterior samples over either the regression function, :math:`\eta`,
764 or :math:`y`.
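
        For example, a sketch of drawing posterior samples in data space at
        the training points (``model`` is a hypothetical fitted `bcf`):

        >>> ys = model.pred(error=True, transformed=False, samples=100)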
765 """
767 # check consistency of output choice
768 if samples is None: 768 ↛ 769line 768 didn't jump to line 769 because the condition on line 768 was never true1abcde
769 if not transformed:
770 raise ValueError('Posterior is required in analytical form '
771 '(samples=None) and in data space '
772 '(transformed=False), this is not possible as '
773 'the transformation model space -> data space '
774 'is arbitrary. Either sample the posterior or '
775 'get the result in model space.')
776 else:
777 if not transformed and not error: 777 ↛ 778line 777 didn't jump to line 778 because the condition on line 777 was never true1abcde
778 raise ValueError('Posterior is required in data space '
779 '(transformed=False) and without error term '
780 '(error=False), this is not possible as the '
781 'transformation model space -> data space '
782 'applies after adding the error.')
783 assert not gvars, 'can not represent posterior samples as gvars' 1abcde
785 # TODO allow exceptions to these rules when there are no transformations
786 # or the only transformation is 'standardize'.
788 # get hyperparameters
789 hp = self._gethp(hp, rng) 1abcde
791 # check presence of covariates is coherent
792 self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux) 1abcde
794 # convert all inputs to arrays compatible with jax to pass them to the
795 # compiled implementation
796 if z is not None: 1abcde
797 z = self._to_vector(z) 1a
798 pihat = self._to_vector(pihat) 1a
799 x_mu = self._to_structured(x_mu) 1a
800 if x_tau is not None: 800 ↛ 801line 800 didn't jump to line 801 because the condition on line 800 was never true1a
801 x_tau = self._to_structured(x_tau)
802 if x_aux is not None: 802 ↛ 803line 802 didn't jump to line 803 because the condition on line 802 was never true1a
803 x_aux = self._to_structured(x_aux)
804 if weights is not None: 804 ↛ 805line 804 didn't jump to line 805 because the condition on line 804 was never true1abcde
805 weights = self._to_vector(weights)
807 # GP regression
808 mean, cov = self._pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw, bool(error)) 1abcde
810 # return Normal posterior moments
811 if samples is None: 811 ↛ 812line 811 didn't jump to line 812 because the condition on line 811 was never true1abcde
812 if gvars:
813 return gvar.gvar(mean, cov, fast=True)
814 else:
815 return mean, cov
817 # sample from posterior
818 sample = jnp.stack(list(_fastraniter.raniter(mean, cov, n=samples, rng=rng))) 1abcde
819 # TODO when I add vectorized sampling, use it here
820 if not transformed: 820 ↛ 822line 820 didn't jump to line 822 because the condition on line 820 was always true1abcde
821 sample = self._to_data(hp, sample) 1abcde
822 return sample 1abcde
824 # TODO the default should be something in data space, so with samples.
825 # If I handle the analyitical posterior through standardize, I could
826 # also make it without samples by default. Although I guess for
827 # whatever calculations samples are more convenient (just do the
828 # calculation on the samples.)
830 @functools.cached_property 1fabcde
831 def _pred(self): 1fabcde
833 @functools.partial(jax.jit, static_argnums=(8,)) 1abcde
834 def _pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw, error): 1abcde
835 gp = self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw) 1abcde
836 data = self.fit.data(hp, **gpfactorykw) 1abcde
837 if z is None: 1abcde
838 label = 'train' 1bcde
839 else:
840 label = 'test' 1a
841 if not error: 841 ↛ 842line 841 didn't jump to line 842 because the condition on line 841 was never true1abcde
842 label += 'mean'
843 outmean, outcov = gp.predfromdata(data, label, raw=True) 1abcde
844 return outmean + hp.get('m', 0), outcov 1abcde
846 # TODO make everything pure and jit this per class instead of per
847 # instance
849 return _pred 1abcde
851 def from_data(self, y, *, hp='map', rng=None): 1fabcde
852 """
853 Transforms outcomes :math:`y` to the regression variable :math:`\\eta`.
855 Parameters
856 ----------
857 y : (n,) array
858 Outcomes.
859 hp : str or dict
860 The hyperparameters to use. If ``'map'`` (default), use the marginal
861 maximum a posteriori. If ``'sample'``, sample hyperparameters from
862 the posterior. If a dict, use the given hyperparameters.
863 rng : numpy.random.Generator, optional
864 Random number generator, used if ``hp == 'sample'``.
866 Returns
867 -------
868 eta : (n,) array
869 Transformed outcomes.
870 """
872 hp = self._gethp(hp, rng) 1bcde
873 return self._from_data(hp, y) 1bcde
875 def to_data(self, eta, *, hp='map', rng=None): 1fabcde
876 """
877 Convert the regression variable :math:`\\eta` to outcomes :math:`y`.
879 Parameters
880 ----------
881 eta : (n,) array
882 Transformed outcomes.
883 hp : str or dict
884 The hyperparameters to use. If ``'map'`` (default), use the marginal
885 maximum a posteriori. If ``'sample'``, sample hyperparameters from
886 the posterior. If a dict, use the given hyperparameters.
887 rng : numpy.random.Generator, optional
888 Random number generator, used if ``hp == 'sample'``.
890 Returns
891 -------
892 y : (n,) array
893 Outcomes.
894 """
896 hp = self._gethp(hp, rng) 1bcde
897 return self._to_data(hp, eta) 1bcde
899 @classmethod 1fabcde
900 def _to_structured(cls, x, *, check_numerical=True): 1fabcde
902 # convert to StructuredArray
903 if hasattr(x, 'columns'): 1abcde
904 x = _array.StructuredArray.from_dataframe(x) 1a
905 elif hasattr(x, 'to_numpy'): 905 ↛ 906line 905 didn't jump to line 906 because the condition on line 905 was never true1abcde
906 x = _array.StructuredArray.from_dict({
907 'f0' if x.name is None else x.name: x.to_numpy()
908 })
909 elif x.dtype.names is None: 1abcde
910 x = _array.unstructured_to_structured(x) 1bcde
911 else:
912 x = _array.StructuredArray(x) 1a
914 # check fields are numerical, for BART
915 if check_numerical: 915 ↛ 923line 915 didn't jump to line 923 because the condition on line 915 was always true1abcde
916 assert x.ndim == 1 1abcde
917 assert x.size > len(x.dtype) 1abcde
918 def check_numerical(path, dtype): 1abcde
919 if not numpy.issubdtype(dtype, numpy.number): 919 ↛ 920line 919 didn't jump to line 920 because the condition on line 919 was never true1abcde
920 raise TypeError(f'covariate `{path}` is not numerical')
921 cls._walk_dtype(x.dtype, check_numerical) 1abcde
923 return x 1abcde
925 @staticmethod 1fabcde
926 def _to_vector(x): 1fabcde
927 if hasattr(x, 'columns'): # dataframe 927 ↛ 928line 927 didn't jump to line 928 because the condition on line 927 was never true1abcde
928 x = x.to_numpy().squeeze(axis=1)
929 elif hasattr(x, 'to_numpy'): # series (dataframe column) 1abcde
930 x = x.to_numpy() 1a
931 x = jnp.asarray(x) 1abcde
932 if x.ndim != 1: 932 ↛ 933line 932 didn't jump to line 933 because the condition on line 932 was never true1abcde
933 raise ValueError(f'array is not 1d vector, ndim={x.ndim}')
934 return x 1abcde
936 @classmethod 1fabcde
937 def _walk_dtype(cls, dtype, task, path=None): 1fabcde
938 if dtype.names is None: 1abcde
939 task(path, dtype) 1abcde
940 else:
941 for name in dtype.names: 1abcde
942 subpath = name if path is None else path + ':' + name 1abcde
943 cls._walk_dtype(dtype[name], task, subpath) 1abcde
945 @staticmethod 1fabcde
946 def _toindices(x, splits): 1fabcde
947 ix = _kernels.BART.indices_from_coord(x, splits) 1abcde
948 dtype = cast(x.dtype, ix.dtype) 1abcde
949 return _array.unstructured_to_structured(ix, dtype=dtype) 1abcde
951 def __repr__(self): 1fabcde
953 with _gvarext.gvar_format(): 1a
954 if hasattr(self.m, 'sdev'): 954 ↛ 955line 954 didn't jump to line 955 because the condition on line 954 was never true1a
955 m = str(self.m)
956 else:
957 m = f'{self.m:.3g}' 1a
959 n = self.fit.gpfactorykw['y'].size 1a
960 p_mu = _array._nd(self.fit.gpfactorykw['i_mu']['x'].dtype) 1a
961 x_tau = self.fit.gpfactorykw['i_tau'] 1a
962 x_aux = self.fit.gpfactorykw['x_aux'] 1a
964 out = f"""\ 1a
965Data:
966 n = {n}"""
968 if x_tau is None: 968 ↛ 972line 968 didn't jump to line 972 because the condition on line 968 was always true1a
969 out += f""" 1a
970 p = {p_mu}"""
971 else:
972 p_tau = _array._nd(x_tau['x'].dtype)
973 out += f"""
974 p_mu/tau = {p_mu}, {p_tau}"""
976 if x_aux is not None: 976 ↛ 977line 976 didn't jump to line 977 because the condition on line 976 was never true1a
977 p_aux = _array._nd(x_aux['x'].dtype)
978 out += f"""
979 p_aux = {p_aux}"""
981 out += f""" 1a
982Hyperparameter posterior:
983 m = {m}
984 z_0 = {self.z_0}
985 alpha_mu/tau = {self.alpha_mu} {self.alpha_tau}
986 beta_mu/tau = {self.beta_mu} {self.beta_tau}
987 lambda_mu/tau = {self.lambda_mu} {self.lambda_tau}"""
989 weights = self.fit.gpfactorykw['weights'] 1a
990 if weights is None: 990 ↛ 995line 990 didn't jump to line 995 because the condition on line 990 was always true1a
991 out += f""" 1a
992 sigma = {self.sigma}"""
994 else:
995 weights = numpy.array(weights) # to avoid jax taking over the ops
996 avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
997 out += f"""
998 sqrt(mean(sigma^2/w)) = {avgsigma}
999 sigma = {self.sigma}"""
1001 out += """ 1a
1002Meaning of hyperparameters:
1003 mu(x) = reference outcome level
1004 tau(x) = effect of the treatment
1005 z_0 in (0, 1): reference treatment level
1006 z_0 -> 0: mu is the model of the untreated
1007 z_0 -> 1: mu is the model of the treated
1008 alpha in (0, 1)
1009 alpha -> 0: constant function
1010 alpha -> 1: no constraints on the function
1011 beta in (0, ∞)
1012 beta -> 0: no constraints on the function
1013 beta -> ∞: no interactions, f(x) = f1(x1) + f2(x2) + ...
1014 lambda in (0, ∞): standard deviation of function
1015 lambda small: confident extrapolation
1016 lambda large: conservative extrapolation
1017 sigma in (0, ∞): standard deviation of i.i.d. error"""
1019 return _utils.top_bottom_rule('BCF', out) 1a
1021 # TODO print user parameters, applying transformations. Copy the dict and use .pop() to remove the predefined params as they are printed.
1023 def _get_transf(self, *, transf, y, weights): 1fabcde
1025 from_datas = [] 1abcde
1026 to_datas = [] 1abcde
1027 hypers = {} 1abcde
1029 if transf is None: 1029 ↛ 1030line 1029 didn't jump to line 1030 because the condition on line 1029 was never true1abcde
1030 transf = []
1031 elif isinstance(transf, list): 1abcde
1032 name = lambda n: f'transf{i}_{n}' 1abcde
1033 else:
1034 name = lambda n: n 1bcde
1035 transf = [transf] 1bcde
1037 for i, tr in enumerate(transf): 1abcde
1039 hyper = {} 1abcde
1041 if not isinstance(tr, str): 1041 ↛ 1043line 1041 didn't jump to line 1043 because the condition on line 1041 was never true1abcde
1043 from_data, to_data = tr
1045 elif tr == 'standardize': 1abcde
1047 if i > 0: 1047 ↛ 1048line 1047 didn't jump to line 1048 because the condition on line 1047 was never true1abcde
1048 warnings.warn('standardization applied after other '
1049 'transformations: standardization always uses the '
1050 'initial data mean and standard deviation, so it may '
1051 'not work as intended')
1053 # It's not possible to overcome this limitation if one wants
1054 # to stick to transformations that act on one point at a
1055 # time to make them generalizable out of sample.
1057 if weights is None: 1057 ↛ 1061line 1057 didn't jump to line 1061 because the condition on line 1057 was always true1abcde
1058 loc = jnp.mean(y) 1abcde
1059 scale = jnp.std(y) 1abcde
1060 else:
1061 loc = jnp.average(y, weights=weights)
1062 scale = jnp.sqrt(jnp.average((y - loc) ** 2, weights=weights))
1064 def from_data(hp, y): 1abcde
1065 return (y - loc) / scale 1abcde
1066 def to_data(hp, eta): 1abcde
1067 return loc + scale * eta 1abcde
1069 elif tr == 'yeojohnson': 1069 ↛ 1078line 1069 didn't jump to line 1078 because the condition on line 1069 was always true1abcde
1071 def from_data(hp, y): 1abcde
1072 return yeojohnson(y, hp[name('lambda_yj')]) 1abcde
1073 def to_data(hp, eta): 1abcde
1074 return yeojohnson_inverse(eta, hp[name('lambda_yj')]) 1abcde
1075 hyper[name('lambda_yj')] = 2 * copula.beta(2, 2) 1abcde
1077 else:
1078 raise KeyError(tr)
1080 from_datas.append(from_data) 1abcde
1081 to_datas.append(to_data) 1abcde
1082 hypers.update(hyper) 1abcde
1084 if transf: 1084 ↛ 1094line 1084 didn't jump to line 1094 because the condition on line 1084 was always true1abcde
1085 def from_data(hp, y): 1abcde
1086 for fd in from_datas: 1abcde
1087 y = fd(hp, y) 1abcde
1088 return y 1abcde
1089 def to_data(hp, eta): 1abcde
1090 for td in reversed(to_datas): 1abcde
1091 eta = td(hp, eta) 1abcde
1092 return eta 1abcde
1093 else:
1094 from_data = lambda hp, y: y
1095 to_data = lambda hp, eta: eta
1097 from_data_grad = _jaxext.elementwise_grad(from_data, 1) 1abcde
1098 def loss(hp): 1abcde
1099 return -jnp.sum(jnp.log(from_data_grad(hp, y))) 1abcde
1101 hypers = copula.makedict(hypers) 1abcde
1103 return from_data, to_data, loss, hypers 1abcde
1105def yeojohnson(x, lmbda): 1fabcde
1106 """ Yeo-Johnson transformation with lamda != 0, 2 """
1107 return jnp.where( 1abcde
1108 x >= 0,
1109 (jnp.power(x + 1, lmbda) - 1) / lmbda,
1110 -((jnp.power(-x + 1, 2 - lmbda) - 1) / (2 - lmbda))
1111 )
1113 # TODO
1114 # - rewrite the cases with expm1, log1p, etc. to make them accurate
1115 # - split the cases into lambda 0/2
1116 # - make custom_jvps for the singular points to define derivatives w.r.t.
1117 # lambda even though it does not appear in the expression
1118 # - add unit tests that check gradients with finite differences
1120def yeojohnson_inverse(y, lmbda): 1fabcde
1121 return jnp.where( 1abcde
1122 y >= 0,
1123 jnp.power(y * lmbda + 1, 1 / lmbda) - 1,
1124 -jnp.power(-(2 - lmbda) * y + 1, 1 / (2 - lmbda)) + 1
1125 )
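
# A quick round-trip sanity check, as a sketch toward the unit tests noted in
# the TODO above (hypothetical, not part of the module):
#
# >>> x = jnp.linspace(-3, 3, 7)
# >>> bool(jnp.allclose(yeojohnson_inverse(yeojohnson(x, 0.7), 0.7), x))
# True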