Coverage for src/lsqfitgp/bayestree/_bcf.py: 80%
360 statements
# lsqfitgp/bayestree/_bcf.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools
import warnings

import numpy
from scipy import stats
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter
from .. import _jaxext
from .. import _gvarext
from .. import _utils

# TODO add methods or options to do causal inference stuff, e.g., impute
# missing outcomes, or ate, att, cate, catt, sate, satt. Remember that the
# effect may also depend on aux. See bartCause, possibly copy its naming.

def _recursive_cast(dtype, default, mapping):
    if dtype in mapping:
        return mapping[dtype]
    elif dtype.names is not None:
        return numpy.dtype([
            (name, _recursive_cast(dtype[name], default, mapping))
            for name in dtype.names
        ])
    elif dtype.subdtype is not None:
        # note: has names => does not have subdtype
        return numpy.dtype((_recursive_cast(dtype.base, default, mapping), dtype.shape))
    elif default is None:
        return dtype
    else:
        return default

def cast(dtype, default, mapping={}):
    """
    Recursively cast a numpy data type.

    Parameters
    ----------
    dtype : dtype
        The data type to cast.
    default : dtype or None
        The leaf fields of `dtype` are cast to `default`, which can be
        structured, unless they appear in `mapping`. If None, dtypes not in
        `mapping` are left unchanged.
    mapping : dict
        A dictionary from dtypes to dtypes, indicating specific casting rules.
        The dtypes can be structured; a match of a structured dtype takes
        precedence over matches in its leaves, and the converted dtype is not
        searched further for matches.

    Returns
    -------
    casted_dtype : dtype
        The cast version of `dtype`. May not have the same structure if
        `mapping` contains structured dtypes.
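
    Examples
    --------
    A small sketch of the casting rules: leaves fall back to `default`
    unless `mapping` names their dtype explicitly.

    >>> dt = numpy.dtype([('a', float), ('b', [('c', int), ('d', bool)])])
    >>> cast(dt, numpy.float32, mapping={bool: 'i1'})
    dtype([('a', '<f4'), ('b', [('c', '<f4'), ('d', 'i1')])])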
82 """
    mapping = {numpy.dtype(k): numpy.dtype(v) for k, v in mapping.items()}
    default = None if default is None else numpy.dtype(default)
    return _recursive_cast(numpy.dtype(dtype), default, mapping)

    # TODO
    # - move this to generic utils
    # - make unit tests

class bcf:

    def __init__(self, *,
        y,
        z,
        x_mu,
        x_tau=None,
        pihat,
        include_pi='mu',
        weights=None,
        fitkw={},
        kernelkw_mu={},
        kernelkw_tau={},
        marginalize_mean=True,
        gpaux=None,
        x_aux=None,
        otherhp={},
        transf='standardize',
    ):
110 r"""
111 Nonparametric Bayesian regression with a GP version of BCF.
113 BCF (Bayesian Causal Forests) is a regression method for observational
114 causal inference studies introduced in [1]_ based on a pair of BART
115 models.
117 This class evaluates a Gaussian process regression with a kernel which
118 accurately approximates BCF in the infinite trees limit of each BART
119 model. The hyperparameters are optimized to their marginal MAP.
121 The model is (loosely, see notes below) :math:`y = \mu(x) + z\tau(x)`,
122 so :math:`\tau(x)` is the expected causal effect of :math:`z` on
123 :math:`y` at location :math:`x`.

        Parameters
        ----------
        y : (n,) array, series or dataframe
            Outcome.
        z : (n,) array, series or dataframe
            Binary treatment status: 0 control group, 1 treatment group.
        x_mu : (n, p) array, series or dataframe
            Covariates for the :math:`\mu` model.
        x_tau : (n, q) array, series or dataframe, optional
            Covariates for the :math:`\tau` model. If not specified, use `x_mu`.
        pihat : (n,) array, series or dataframe
            Estimated propensity score, i.e., P(Z=1|X).
        include_pi : {'mu', 'tau', 'both'}, optional
            Whether to include the propensity score in the :math:`\mu` model,
            the :math:`\tau` model, or both. Default is ``'mu'``.
        weights : (n,) array, series or dataframe, optional
            Weights used to rescale the error variance (as 1 / weight).
        fitkw : dict
            Additional arguments passed to `~lsqfitgp.empbayes_fit`, overriding
            the defaults.
        kernelkw_mu, kernelkw_tau : dict
            Additional arguments passed to `~lsqfitgp.BART` for each model,
            overriding the defaults.
        marginalize_mean : bool
            If True (default), marginalize the intercept of the model.
        gpaux : callable, optional
            If specified, this function is called with a pair ``(hp, gp)``,
            where ``hp`` is a dictionary of hyperparameters, and ``gp`` is a
            `~lsqfitgp.GP` object under construction, and is expected to return
            a modified ``gp`` with a new process named ``'aux'`` defined with
            `~lsqfitgp.GP.defproc` or similar. The process is added to the
            regression model. The input to the process is a structured array
            with fields:

            'train' : bool
                Indicates whether the data is in the training set (the one
                passed on initialization) or in the test set (the one passed
                to `pred` or `gp`).
            'i' : int
                Index of the flattened array.
            'z' : int
                Treatment status.
            'mu', 'tau' : structured
                The values in `x_mu` and `x_tau`, converted to indices
                according to the BART grids. Where `pihat` has been added,
                there are two subfields: ``'x'``, which contains the
                covariates, and ``'pihat'``, the latter expressed in indices
                as well.
            'pihat' : float
                The `pihat` argument. Unlike the subfield included under
                ``'mu'`` and/or ``'tau'``, this field contains the original
                values.
            'aux' : structured
                The values in `x_aux`, if specified.

        x_aux : (n, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process.
        otherhp : dictionary of gvar
            A dictionary with the priors of arbitrary additional
            hyperparameters, intended to be used by ``gpaux`` or ``transf``.
        transf : (list of) str or pair of callable
            Data transformation. Either a string indicating a pre-defined
            transformation, or a pair ``(from_data, to_data)`` of functions
            with signatures ``from_data(hp, y) -> eta`` and ``to_data(hp, eta)
            -> y``, where ``eta`` is the value the model is fit to, and ``hp``
            is the dictionary of hyperparameters. The functions must be
            ufuncs, each the inverse of the other w.r.t. the second parameter,
            and ``from_data`` must be differentiable with `jax` w.r.t. ``y``.

            If a list of such specifications is provided, the transformations
            are applied in order, with the first one being the outermost,
            i.e., the one applied first to the data.

            If a transformation uses additional hyperparameters, either
            predefined automatically or passed by the user through `otherhp`,
            they are inferred together with the rest of the hyperparameters.

            The pre-defined transformations are:

            'standardize' (default)
                eta = (y - mean(train_y)) / sdev(train_y)
            'yeojohnson'
                The Yeo-Johnson transformation [2]_ to reduce skewness. The
                :math:`\lambda` parameter is bounded in :math:`(0, 2)` for
                implementation convenience; this restriction may be lifted in
                future versions.

        Notes
        -----
        The regression model is:

        .. math::
            \eta_i = g(y_i; \ldots) &= m + {} \\
            &\phantom{{}={}} +
                \lambda_\mu
                \mu(\mathbf x^\mu_i, \hat\pi_i?) + {} \\
            &\phantom{{}={}} +
                \lambda_\tau
                \tau(\mathbf x^\tau_i, \hat\pi_i?) (z_i - z_0) + {} \\
            &\phantom{{}={}} +
                \mathrm{aux}(i, z_i, \mathbf x^\mu_i, \mathbf x^\tau_i,
                \hat\pi_i, \mathbf x^\text{aux}_i) + {} \\
            &\phantom{{}={}} +
                \varepsilon_i, \\
            \varepsilon_i &\sim N(0, \sigma^2 / w_i), \\
            m &\sim N(0, 1), \\
            \log \sigma^2 &\sim N(\log\bar w, 4), \\
            \lambda_\mu &\sim \mathrm{HalfCauchy}(2), \\
            \lambda_\tau &\sim \mathrm{HalfNormal}(1.48), \\
            \mu &\sim \mathrm{GP}(0, \mathrm{BART}(\alpha_\mu, \beta_\mu)), \\
            \tau &\sim \mathrm{GP}(0, \mathrm{BART}(\alpha_\tau, \beta_\tau)), \\
            \mathrm{aux} &\sim \mathrm{GP}(0, \text{<user defined>}), \\
            \alpha_\mu, \alpha_\tau &\sim \mathrm{Beta}(2, 1), \\
            \beta_\mu, \beta_\tau &\sim \mathrm{InvGamma}(1, 1), \\
            z_0 &\sim U(0, 1).

        To make the inference, :math:`(\mu, \tau, \boldsymbol\varepsilon, m,
        \mathrm{aux})` are marginalized analytically, and the marginal
        posterior mode of
        :math:`(\sigma, \lambda_*, \alpha_*, \beta_*, z_0, \ldots)` is found
        by numerical minimization, after transforming them to express their
        prior as a Gaussian copula. Their marginal posterior covariance
        matrix is estimated with an approximation of the hessian inverse. See
        `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize
        this procedure.

        The tree splitting grid of the BART kernel is set using quantiles of
        the observed covariates. This corresponds to settings
        ``usequants=True``, ``numcut=inf`` in the R packages BayesTree and
        BART. Use the parameters `kernelkw_mu` and `kernelkw_tau` to
        customize the grids.

        The difference between the regression model evaluated at :math:`Z=1`
        vs. :math:`Z=0` can be interpreted as the causal effect :math:`Z
        \rightarrow Y` if the unconfoundedness assumption is made:

        .. math::
            \{Y(Z=0), Y(Z=1)\} \perp\!\!\!\perp Z \mid X.

        In practical terms, this holds when:

        1) :math:`X` are pre-treatment variables, i.e., they represent
           quantities causally upstream of :math:`Z`;

        2) :math:`X` are sufficient to adjust for all common causes of
           :math:`Z` and :math:`Y`, such that the only remaining difference
           is the causal effect and not just a correlation.

        Here :math:`X` consists of `x_tau`, `x_mu` and `x_aux`. However, these
        arrays may also be used to pass "technical" values used to set up the
        model that do not satisfy the unconfoundedness assumption, if you know
        what you are doing.

        Attributes
        ----------
        m : float or gvar
            The prior mean :math:`m`.
        sigma : gvar
            The error term standard deviation :math:`\sigma`. If there are
            weights, the sdev for each unit is obtained dividing ``sigma`` by
            sqrt(weight).
        alpha_mu, alpha_tau : gvar
            The numerator of the tree spawn probability :math:`\alpha_*`
            (named ``base`` in R bcf).
        beta_mu, beta_tau : gvar
            The depth exponent of the tree spawn probability :math:`\beta_*`
            (named ``power`` in R bcf).
        lambda_mu, lambda_tau : gvar
            The prior standard deviation :math:`\lambda_*`.
        z_0 : gvar
            The treatment coding parameter.
        fit : empbayes_fit
            The hyperparameters fit object.

        Methods
        -------
        gp :
            Create a GP object.
        data :
            Create the dictionary to be passed to `GP.pred` to represent data.
        pred :
            Evaluate the regression function at given locations.
        from_data :
            Convert :math:`y` to :math:`\eta`.
        to_data :
            Convert :math:`\eta` to :math:`y`.

        See also
        --------
        lsqfitgp.BART

        References
        ----------
        .. [1] P. Richard Hahn, Jared S. Murray, Carlos M. Carvalho, "Bayesian
            Regression Tree Models for Causal Inference: Regularization,
            Confounding, and Heterogeneous Effects (with Discussion)," Bayesian
            Analysis 15(3), 965-1056, September 2020,
            https://doi.org/10.1214/19-BA1195
        .. [2] Yeo, In-Kwon; Johnson, Richard A. (2000). "A New Family of Power
            Transformations to Improve Normality or Symmetry". Biometrika. 87
            (4): 954–959. https://doi.org/10.1093/biomet/87.4.954
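
        Examples
        --------
        A minimal synthetic-data sketch (hypothetical names; assumes the class
        is accessible as ``lsqfitgp.bayestree.bcf``, and that in a real
        analysis `pihat` would come from a separately fitted propensity
        model):

        >>> import numpy as np
        >>> import lsqfitgp as lgp
        >>> rng = np.random.default_rng(0)
        >>> n = 100
        >>> x = rng.standard_normal((n, 3))
        >>> pihat = 1 / (1 + np.exp(-x[:, 0]))  # stand-in for an estimated score
        >>> z = rng.binomial(1, pihat)
        >>> y = x @ [1., 2., 3.] + z * (1 + x[:, 0]) + rng.standard_normal(n)
        >>> model = lgp.bayestree.bcf(y=y, z=z, x_mu=x, pihat=pihat,
        ...     fitkw=dict(verbosity=0))
        >>> eta1, _ = model.pred(z=np.ones(n), x_mu=x, pihat=pihat)
        >>> eta0, _ = model.pred(z=np.zeros(n), x_mu=x, pihat=pihat)
        >>> effect = eta1 - eta0  # posterior mean of the effect, in eta space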
328 """

        # convert covariates to StructuredArray
        x_mu = self._to_structured(x_mu)
        if x_tau is not None:
            x_tau = self._to_structured(x_tau)
            assert x_tau.shape == x_mu.shape
        if x_aux is not None:
            x_aux = self._to_structured(x_aux)
            assert x_aux.shape == x_mu.shape

        # convert outcomes, treatment, propensity score, weights to 1d arrays
        y = self._to_vector(y)
        z = self._to_vector(z)
        pihat = self._to_vector(pihat)
        assert y.shape == z.shape == pihat.shape == x_mu.shape
        if weights is not None:
            weights = self._to_vector(weights)
            assert weights.shape == x_mu.shape

        # check include_pi
        if include_pi not in ('mu', 'tau', 'both'):
            raise KeyError(f'invalid value include_pi={include_pi!r}')
        self._include_pi = include_pi

        # add pihat to covariates
        x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat)

        # grid and indices
        splits_mu = _kernels.BART.splits_from_coord(x_mu)
        i_mu = self._toindices(x_mu, splits_mu)
        if x_tau is None:
            splits_tau = splits_mu
            i_tau = None
        else:
            splits_tau = _kernels.BART.splits_from_coord(x_tau)
            i_tau = self._toindices(x_tau, splits_tau)

        # get functions for data transformation
        from_data, to_data, transfloss, transfhp = self._get_transf(
            transf=transf, weights=weights, y=y)

        # scale of error variance
        logsigma2_loc = 0 if weights is None else numpy.log(jnp.mean(weights))

        # prior on hyperparams
        hyperprior = copula.makedict({
            'm': gvar.gvar(0, 1),
            'sigma^2': copula.lognorm(logsigma2_loc, 2),
            'lambda_mu': copula.halfcauchy(2),
            'lambda_tau': copula.halfnorm(1.48),
            'alpha_mu': copula.beta(2, 1),
            'alpha_tau': copula.beta(2, 1),
            'beta_mu': copula.invgamma(1, 1),
            'beta_tau': copula.invgamma(1, 1),
            'z_0': copula.uniform(0, 1),
        })

        # remove explicit mean parameter if it's baked into the Gaussian process
        if marginalize_mean:
            hyperprior.pop('m')

        # add data transformation and user hyperparameters
        def update_hyperparams(new, newname, raises):
            new = gvar.BufferDict(new)
            for key in new.all_keys():
                if hyperprior.has_dictkey(key):
                    message = f'{newname} hyperparameter {key!r} overrides existing one'
                    if raises:
                        raise ValueError(message)
                    else:
                        warnings.warn(message)
            hyperprior.update(new)
        update_hyperparams(transfhp, 'data transformation', True)
            # the hypers handed by _get_transf are not allowed to override
        update_hyperparams(otherhp, 'user', False)

        # GP factory
        def gpfactory(hp, *, z, i_mu, i_tau, pihat, x_aux, weights,
            splits_mu, splits_tau, **_):

            # TODO maybe I should pass kernelkw_* as arguments, but they may
            # not be jittable. I need jitkw in empbayes_fit for that.

            kw_overridable = dict(
                maxd=10,
                reset=[2, 4, 6, 8],
                intercept=False,
            )
            kw_not_overridable = dict(indices=True)

            gp = _GP.GP(checkpos=False, checksym=False, solver='chol')

            for name, kernelkw in dict(mu=kernelkw_mu, tau=kernelkw_tau).items():
                kw = dict(
                    alpha=hp[f'alpha_{name}'],
                    beta=hp[f'beta_{name}'],
                    dim=name,
                    splits=eval(f'splits_{name}'),
                    **kw_overridable,
                )
                kw.update(kernelkw)
                kernel = _kernels.BART(**kw, **kw_not_overridable)
                kernel *= hp[f'lambda_{name}'] ** 2

                gp = gp.defproc(name, kernel)

            if 'm' in hp:
                kernel_mean = 0 * _kernels.Constant()
            else:
                kernel_mean = _kernels.Constant()
            gp = gp.defproc('m', kernel_mean)

            if gpaux is None:
                gp = gp.defproc('aux', 0 * _kernels.Constant())
            else:
                gp = gpaux(hp, gp)

            gp = gp.deflintransf(
                gp.DefaultProcess,
                lambda m, mu, tau, aux: lambda x:
                    m(x) + mu(x) + tau(x) * (x['z'] - hp['z_0']) + aux(x),
                ['m', 'mu', 'tau', 'aux'],
            )

            x = self._join_points(True, z, i_mu, i_tau, pihat, x_aux)
            gp = gp.addx(x, 'trainmean')
            errcov = self._error_cov(hp, weights, x)
            return (gp
                .addcov(errcov, 'trainnoise')
                .addtransf({'trainmean': 1, 'trainnoise': 1}, 'train')
            )

        # data factory
        def data(hp, *, y, **_):
            return {'train': from_data(hp, y) - hp.get('m', 0)}

        # fit hyperparameters
        gpkw = dict(
            y=y,
            z=z,
            i_mu=i_mu,
            i_tau=i_tau,
            pihat=pihat,
            x_aux=x_aux,
            weights=weights,
            splits_mu=splits_mu,
            splits_tau=splits_tau,
        )
        options = dict(
            verbosity=3,
            minkw=dict(method='l-bfgs-b', options=dict(maxls=4, maxiter=100)),
            mlkw=dict(epsrel=0),
            forward=True,
            gpfactorykw=gpkw,
            additional_loss=transfloss,
        )
        options.update(fitkw)
        fit = _fit.empbayes_fit(hyperprior, gpfactory, data, **options)

        # extract hyperparameters from minimization result
        self.m = fit.p.get('m', 0)
        self.sigma = gvar.sqrt(fit.p['sigma^2'])
        self.lambda_mu = fit.p['lambda_mu']
        self.lambda_tau = fit.p['lambda_tau']
        self.alpha_mu = fit.p['alpha_mu']
        self.alpha_tau = fit.p['alpha_tau']
        self.beta_mu = fit.p['beta_mu']
        self.beta_tau = fit.p['beta_tau']
        self.z_0 = fit.p['z_0']

        # save other attributes
        self.fit = fit
        self._from_data = from_data
        self._to_data = to_data

    def _append_pihat(self, x_mu, x_tau, pihat):
        ip = self._include_pi
        if ip == 'mu' or ip == 'both':
            x_mu = _array.StructuredArray.from_dict(dict(
                x=x_mu,
                pihat=pihat,
            ))
        if x_tau is not None and (ip == 'tau' or ip == 'both'):
            x_tau = _array.StructuredArray.from_dict(dict(
                x=x_tau,
                pihat=pihat,
            ))
        return x_mu, x_tau

    @staticmethod
    def _join_points(train, z, i_mu, i_tau, pihat, x_aux):
        """ join covariates into a single StructuredArray """
        columns = dict(
            train=jnp.broadcast_to(bool(train), z.shape),
            i=jnp.arange(z.size).reshape(z.shape),
            z=z,
            mu=i_mu,
            tau=i_mu if i_tau is None else i_tau,
            pihat=pihat,
        )
        if x_aux is not None:
            columns.update(aux=x_aux)
        return _array.StructuredArray.from_dict(columns)

    @staticmethod
    def _error_cov(hp, weights, x):
        """ fill error covariance matrix """
        if weights is None:
            error_var = jnp.broadcast_to(hp['sigma^2'], len(x))
        else:
            error_var = hp['sigma^2'] / weights
        return jnp.diag(error_var)

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', z=None, x_mu=None, x_tau=None, pihat=None,
        x_aux=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        z : (m,) array, series or dataframe, optional
            Treatment status at test points. If specified, `x_mu` and `pihat`
            must be specified as well, and `x_tau` and `x_aux` if and only if
            they were specified at initialization.
        x_mu : (m, p) array, series or dataframe, optional
            Control model covariates at test points.
        x_tau : (m, q) array, series or dataframe, optional
            Moderating model covariates at test points.
        pihat : (m,) array, series or dataframe, optional
            Estimated propensity score at test points.
        x_aux : (m, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process.
        weights : (m,) array, series or dataframe, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the `m`
            attribute of the `bcf` object. The keys of the GP are ``'@mean'``,
            ``'@noise'``, and ``'@'``, where the "@" stands either for 'train'
            or 'test', and @ = @mean + @noise.

            This Gaussian process is defined on the transformed data ``eta``.
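
        Examples
        --------
        A sketch of manual prediction with the returned object (assumes
        ``model`` is a fitted `bcf` instance):

        >>> gp = model.gp()  # GP over the training points
        >>> mean, cov = gp.predfromdata(model.data(), 'trainmean', raw=True)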
589 """
591 hp = self._gethp(hp, rng)
592 return self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw)

    def _gp(self, hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw):
        """
        Internal function to create the GP object. This function must work
        both if the arguments are user-provided and need to be checked and
        converted to standard format, and if they are traced jax values.
        """

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if z is not None:

            # check presence/absence of arguments is coherent
            self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux)

            # check treatment and propensity score
            z = self._to_vector(z)
            pihat = self._to_vector(pihat)
            assert pihat.shape == z.shape

            # check weights
            if weights is not None:
                weights = self._to_vector(weights)
                assert weights.shape == z.shape

            # add propensity score to covariates
            x_mu = self._to_structured(x_mu)
            assert x_mu.shape == z.shape
            if x_tau is not None:
                x_tau = self._to_structured(x_tau)
                assert x_tau.shape == z.shape
            x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat)

            # convert covariates to indices
            i_mu = self._toindices(x_mu, gpfactorykw['splits_mu'])
            assert i_mu.dtype == gpfactorykw['i_mu'].dtype
            if x_tau is not None:
                i_tau = self._toindices(x_tau, gpfactorykw['splits_tau'])
                assert i_tau.dtype == gpfactorykw['i_tau'].dtype
            else:
                i_tau = None

            # check auxiliary points
            if x_aux is not None:
                x_aux = self._to_structured(x_aux)

            # add test points
            x = self._join_points(False, z, i_mu, i_tau, pihat, x_aux)
            gp = gp.addx(x, 'testmean')
            errcov = self._error_cov(hp, weights, x)
            gp = (gp
                .addcov(errcov, 'testnoise')
                .addtransf({'testmean': 1, 'testnoise': 1}, 'test')
            )

        return gp

    def _check_coherent_covariates(self, z, x_mu, x_tau, pihat, x_aux):
        if z is None:
            assert x_mu is None
            assert x_tau is None
            assert pihat is None
            assert x_aux is None
        else:
            assert x_mu is not None
            assert pihat is not None
            train_tau = self.fit.gpfactorykw['i_tau']
            if x_tau is None:
                assert train_tau is None
            else:
                assert train_tau is not None
            train_aux = self.fit.gpfactorykw['x_aux']
            if x_aux is None:
                assert train_aux is None
            else:
                assert train_aux is not None

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``eta`` in the format required by the
            `GP.pred` method.
        """

        hp = self._gethp(hp, rng)
        return self.fit.data(hp, **self.fit.gpfactorykw)

    def pred(self, *, hp='map', error=False, z=None, x_mu=None, x_tau=None,
        pihat=None, x_aux=None, weights=None, transformed=True, samples=None,
        gvars=False, rng=None):
        r"""
        Predict the transformed outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        error : bool, default False
            If ``False``, make a prediction for the latent mean. If ``True``,
            add the error term.
        z : (m,) array, series or dataframe, optional
            Treatment status at test points. If specified, `x_mu` and `pihat`
            must be specified as well, and `x_tau` and `x_aux` if and only if
            they were specified at initialization.
        x_mu : (m, p) array, series or dataframe, optional
            :math:`\mu` model covariates at test points.
        x_tau : (m, q) array, series or dataframe, optional
            :math:`\tau` model covariates at test points.
        pihat : (m,) array, series or dataframe, optional
            Estimated propensity score at test points.
        x_aux : (m, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process at test points.
        weights : (m,) array, series or dataframe, optional
            Weights for the error variance on the test points.
        transformed : bool, default True
            If ``True``, return the prediction on the transformed outcome
            :math:`\eta`, else on the observable outcome :math:`y`.
        samples : int, optional
            If specified, the number of samples to take from the posterior.
            If not, return the mean and covariance matrix of the posterior.
        gvars : bool, default False
            If ``True``, return the mean and covariance matrix of the
            posterior as an array of `GVar` variables.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'`` or ``samples``
            is not `None`.

        Returns
        -------
        If ``samples`` is `None` and ``gvars`` is `False` (default):

        mean, cov : (m,) and (m, m) arrays
            The mean and covariance matrix of the Normal posterior
            distribution over the regression function or :math:`\eta` at the
            specified locations.

        If ``samples`` is `None` and ``gvars`` is `True`:

        out : (m,) array of gvars
            The same distribution represented as an array of `~gvar.GVar`
            objects.

        If ``samples`` is an integer:

        sample : (samples, m) array
            Posterior samples over either the regression function,
            :math:`\eta`, or :math:`y`.
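
        Examples
        --------
        A sketch, assuming ``model`` is a fitted `bcf` instance and ``z``,
        ``x``, ``pihat`` are its training arrays:

        >>> eta_mean, eta_cov = model.pred(z=z, x_mu=x, pihat=pihat)
        >>> y_samples = model.pred(z=z, x_mu=x, pihat=pihat, error=True,
        ...     transformed=False, samples=1000)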
758 """
760 # check consistency of output choice
761 if samples is None: 761 ↛ 762line 761 didn't jump to line 762 because the condition on line 761 was never true1abcde
762 if not transformed:
763 raise ValueError('Posterior is required in analytical form '
764 '(samples=None) and in data space '
765 '(transformed=False), this is not possible as '
766 'the transformation model space -> data space '
767 'is arbitrary. Either sample the posterior or '
768 'get the result in model space.')
769 else:
770 if not transformed and not error: 770 ↛ 771line 770 didn't jump to line 771 because the condition on line 770 was never true1abcde
771 raise ValueError('Posterior is required in data space '
772 '(transformed=False) and without error term '
773 '(error=False), this is not possible as the '
774 'transformation model space -> data space '
775 'applies after adding the error.')
776 assert not gvars, 'can not represent posterior samples as gvars' 1abcde
778 # TODO allow exceptions to these rules when there are no transformations
779 # or the only transformation is 'standardize'.
781 # get hyperparameters
782 hp = self._gethp(hp, rng) 1abcde
784 # check presence of covariates is coherent
785 self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux) 1abcde
787 # convert all inputs to arrays compatible with jax to pass them to the
788 # compiled implementation
789 if z is not None: 1abcde
790 z = self._to_vector(z) 1a
791 pihat = self._to_vector(pihat) 1a
792 x_mu = self._to_structured(x_mu) 1a
793 if x_tau is not None: 793 ↛ 794line 793 didn't jump to line 794 because the condition on line 793 was never true1a
794 x_tau = self._to_structured(x_tau)
795 if x_aux is not None: 795 ↛ 796line 795 didn't jump to line 796 because the condition on line 795 was never true1a
796 x_aux = self._to_structured(x_aux)
797 if weights is not None: 797 ↛ 798line 797 didn't jump to line 798 because the condition on line 797 was never true1abcde
798 weights = self._to_vector(weights)
800 # GP regression
801 mean, cov = self._pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw, bool(error)) 1abcde
803 # return Normal posterior moments
804 if samples is None: 804 ↛ 805line 804 didn't jump to line 805 because the condition on line 804 was never true1abcde
805 if gvars:
806 return gvar.gvar(mean, cov, fast=True)
807 else:
808 return mean, cov
810 # sample from posterior
811 sample = jnp.stack(list(_fastraniter.raniter(mean, cov, n=samples, rng=rng))) 1abcde
812 # TODO when I add vectorized sampling, use it here
813 if not transformed: 813 ↛ 815line 813 didn't jump to line 815 because the condition on line 813 was always true1abcde
814 sample = self._to_data(hp, sample) 1abcde
815 return sample 1abcde
817 # TODO the default should be something in data space, so with samples.
818 # If I handle the analyitical posterior through standardize, I could
819 # also make it without samples by default. Although I guess for
820 # whatever calculations samples are more convenient (just do the
821 # calculation on the samples.)

    @functools.cached_property
    def _pred(self):

        @functools.partial(jax.jit, static_argnums=(8,))
        def _pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw, error):
            gp = self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw)
            data = self.fit.data(hp, **gpfactorykw)
            if z is None:
                label = 'train'
            else:
                label = 'test'
            if not error:
                label += 'mean'
            outmean, outcov = gp.predfromdata(data, label, raw=True)
            return outmean + hp.get('m', 0), outcov

        # TODO make everything pure and jit this per class instead of per
        # instance

        return _pred

    def from_data(self, y, *, hp='map', rng=None):
        """
        Transform outcomes :math:`y` to the regression variable :math:`\\eta`.

        Parameters
        ----------
        y : (n,) array
            Outcomes.
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        eta : (n,) array
            Transformed outcomes.
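
        Examples
        --------
        A round-trip sketch (assumes ``model`` is a fitted `bcf` instance and
        ``y`` its training outcomes; `from_data` and `to_data` are inverses
        of each other):

        >>> eta = model.from_data(y)
        >>> numpy.allclose(model.to_data(eta), y)
        True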
863 """
865 hp = self._gethp(hp, rng) 1bcde
866 return self._from_data(hp, y) 1bcde

    def to_data(self, eta, *, hp='map', rng=None):
        """
        Convert the regression variable :math:`\\eta` to outcomes :math:`y`.

        Parameters
        ----------
        eta : (n,) array
            Transformed outcomes.
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        y : (n,) array
            Outcomes.
        """

        hp = self._gethp(hp, rng)
        return self._to_data(hp, eta)

    @classmethod
    def _to_structured(cls, x, *, check_numerical=True):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif hasattr(x, 'to_numpy'):
            x = _array.StructuredArray.from_dict({
                'f0' if x.name is None else x.name: x.to_numpy()
            })
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check fields are numerical, for BART
        if check_numerical:
            assert x.ndim == 1
            assert x.size > len(x.dtype)
            def check_numerical(path, dtype):
                if not numpy.issubdtype(dtype, numpy.number):
                    raise TypeError(f'covariate `{path}` is not numerical')
            cls._walk_dtype(x.dtype, check_numerical)

        return x

    @staticmethod
    def _to_vector(x):
        if hasattr(x, 'columns'): # dataframe
            x = x.to_numpy().squeeze(axis=1)
        elif hasattr(x, 'to_numpy'): # series (dataframe column)
            x = x.to_numpy()
        x = jnp.asarray(x)
        if x.ndim != 1:
            raise ValueError(f'array is not 1d vector, ndim={x.ndim}')
        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        dtype = cast(x.dtype, ix.dtype)
        return _array.unstructured_to_structured(ix, dtype=dtype)

    def __repr__(self):

        with _gvarext.gvar_format():

            if hasattr(self.m, 'sdev'):
                m = str(self.m)
            else:
                m = f'{self.m:.3g}'

            n = self.fit.gpfactorykw['y'].size
            p_mu = _array._nd(self.fit.gpfactorykw['i_mu']['x'].dtype)
            x_tau = self.fit.gpfactorykw['i_tau']
            x_aux = self.fit.gpfactorykw['x_aux']

            out = f"""\
Data:
    n = {n}"""

            if x_tau is None:
                out += f"""
    p = {p_mu}"""
            else:
                p_tau = _array._nd(x_tau['x'].dtype)
                out += f"""
    p_mu/tau = {p_mu}, {p_tau}"""

            if x_aux is not None:
                p_aux = _array._nd(x_aux['x'].dtype)
                out += f"""
    p_aux = {p_aux}"""

            out += f"""
Hyperparameter posterior:
    m = {m}
    z_0 = {self.z_0}
    alpha_mu/tau = {self.alpha_mu} {self.alpha_tau}
    beta_mu/tau = {self.beta_mu} {self.beta_tau}
    lambda_mu/tau = {self.lambda_mu} {self.lambda_tau}"""

            weights = self.fit.gpfactorykw['weights']
            if weights is None:
                out += f"""
    sigma = {self.sigma}"""
            else:
                weights = numpy.array(weights) # to avoid jax taking over the ops
                avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
                out += f"""
    sqrt(mean(sigma^2/w)) = {avgsigma}
    sigma = {self.sigma}"""

            out += """
Meaning of hyperparameters:
    mu(x) = reference outcome level
    tau(x) = effect of the treatment
    z_0 in (0, 1): reference treatment level
        z_0 -> 0: mu is the model of the untreated
        z_0 -> 1: mu is the model of the treated
    alpha in (0, 1)
        alpha -> 0: constant function
        alpha -> 1: no constraints on the function
    beta in (0, ∞)
        beta -> 0: no constraints on the function
        beta -> ∞: no interactions, f(x) = f1(x1) + f2(x2) + ...
    lambda in (0, ∞): standard deviation of function
        lambda small: confident extrapolation
        lambda large: conservative extrapolation
    sigma in (0, ∞): standard deviation of i.i.d. error"""

        return _utils.top_bottom_rule('BCF', out)

        # TODO print user parameters, applying transformations. Copy the dict
        # and use .pop() to remove the predefined params as they are printed.

    def _get_transf(self, *, transf, y, weights):

        from_datas = []
        to_datas = []
        hypers = {}

        if transf is None:
            transf = []
        elif isinstance(transf, list):
            name = lambda n: f'transf{i}_{n}'
        else:
            name = lambda n: n
            transf = [transf]

        for i, tr in enumerate(transf):

            hyper = {}

            if not isinstance(tr, str):

                from_data, to_data = tr

            elif tr == 'standardize':

                if i > 0:
                    warnings.warn('standardization applied after other '
                        'transformations: standardization always uses the '
                        'initial data mean and standard deviation, so it may '
                        'not work as intended')
                    # It's not possible to overcome this limitation if one
                    # wants to stick to transformations that act on one point
                    # at a time, to make them generalizable out of sample.

                if weights is None:
                    loc = jnp.mean(y)
                    scale = jnp.std(y)
                else:
                    loc = jnp.average(y, weights=weights)
                    scale = jnp.sqrt(jnp.average((y - loc) ** 2, weights=weights))

                def from_data(hp, y):
                    return (y - loc) / scale
                def to_data(hp, eta):
                    return loc + scale * eta

            elif tr == 'yeojohnson':

                def from_data(hp, y):
                    return yeojohnson(y, hp[name('lambda_yj')])
                def to_data(hp, eta):
                    return yeojohnson_inverse(eta, hp[name('lambda_yj')])
                hyper[name('lambda_yj')] = 2 * copula.beta(2, 2)

            else:
                raise KeyError(tr)

            from_datas.append(from_data)
            to_datas.append(to_data)
            hypers.update(hyper)

        if transf:
            def from_data(hp, y):
                for fd in from_datas:
                    y = fd(hp, y)
                return y
            def to_data(hp, eta):
                for td in reversed(to_datas):
                    eta = td(hp, eta)
                return eta
        else:
            from_data = lambda hp, y: y
            to_data = lambda hp, eta: eta

        from_data_grad = _jaxext.elementwise_grad(from_data, 1)
        def loss(hp):
            return -jnp.sum(jnp.log(from_data_grad(hp, y)))

        hypers = copula.makedict(hypers)

        return from_data, to_data, loss, hypers

def yeojohnson(x, lmbda):
    """ Yeo-Johnson transformation for lmbda != 0, 2 """
    return jnp.where(
        x >= 0,
        (jnp.power(x + 1, lmbda) - 1) / lmbda,
        -((jnp.power(-x + 1, 2 - lmbda) - 1) / (2 - lmbda))
    )

    # TODO
    # - rewrite the cases with expm1, log1p, etc. to make them accurate
    # - split the cases into lambda 0/2
    # - make custom_jvps for the singular points to define derivatives w.r.t.
    #   lambda even though it does not appear in the expression
    # - add unit tests that check gradients with finite differences

def yeojohnson_inverse(y, lmbda):
    return jnp.where(
        y >= 0,
        jnp.power(y * lmbda + 1, 1 / lmbda) - 1,
        -jnp.power(-(2 - lmbda) * y + 1, 1 / (2 - lmbda)) + 1
    )
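
# A minimal round-trip sanity check for the pair above: a sketch, not part of
# the library API. Because of the relative imports, it runs only via
# `python -m lsqfitgp.bayestree._bcf`, not by executing the file directly.
if __name__ == '__main__':
    x = jnp.linspace(-3, 3, 7)
    for lmbda in [0.5, 1.0, 1.5]:
        # the inverse should reproduce x away from the singular lmbda = 0, 2
        assert jnp.allclose(yeojohnson_inverse(yeojohnson(x, lmbda), lmbda), x, atol=1e-6)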