Coverage for src/lsqfitgp/bayestree/_bcf.py: 80%

360 statements  

coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

# lsqfitgp/bayestree/_bcf.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools
import warnings

import numpy
from scipy import stats
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter
from .. import _jaxext
from .. import _gvarext
from .. import _utils

# TODO add methods or options to do causal inference stuff, e.g., impute missing
# outcomes, or ate, att, cate, catt, sate, satt. Remember that the effect may
# also depend on aux. See bartCause, possibly copy its naming.

def _recursive_cast(dtype, default, mapping):
    if dtype in mapping:
        return mapping[dtype]
    elif dtype.names is not None:
        return numpy.dtype([
            (name, _recursive_cast(dtype[name], default, mapping))
            for name in dtype.names
        ])
    elif dtype.subdtype is not None:
        # note: has names => does not have subdtype
        return numpy.dtype((_recursive_cast(dtype.base, default, mapping), dtype.shape))
    elif default is None:
        return dtype
    else:
        return default

def cast(dtype, default, mapping={}):
    """
    Recursively cast a numpy data type.

    Parameters
    ----------
    dtype : dtype
        The data type to cast.
    default : dtype or None
        The leaf fields of `dtype` are cast to `default`, which can be
        structured, unless they appear in `mapping`. If None, dtypes not in
        `mapping` are left unchanged.
    mapping : dict
        A dictionary from dtypes to dtypes, indicating specific casting rules.
        The dtypes can be structured; a match of a structured dtype takes
        precedence over matches in its leaves, and the converted dtype is not
        further searched for matches.

    Returns
    -------
    casted_dtype : dtype
        The cast version of `dtype`. May not have the same structure if
        `mapping` contains structured dtypes.
    """
    mapping = {numpy.dtype(k): numpy.dtype(v) for k, v in mapping.items()}
    default = None if default is None else numpy.dtype(default)
    return _recursive_cast(numpy.dtype(dtype), default, mapping)

    # TODO
    # - move this to generic utils
    # - make unit tests
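
# A minimal usage sketch of `cast` (illustrative, assuming a little-endian
# platform for the reprs): `mapping` wins over `default` on matching leaves.
#
#     >>> dt = numpy.dtype([('a', 'f8'), ('b', 'i8')])
#     >>> cast(dt, 'f4', mapping={'i8': 'i4'})
#     dtype([('a', '<f4'), ('b', '<i4')])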

class bcf:

    def __init__(self, *,
        y,
        z,
        x_mu,
        x_tau=None,
        pihat,
        include_pi='mu',
        weights=None,
        fitkw={},
        kernelkw_mu={},
        kernelkw_tau={},
        marginalize_mean=True,
        gpaux=None,
        x_aux=None,
        otherhp={},
        transf='standardize',
    ):
        r"""
        Nonparametric Bayesian regression with a GP version of BCF.

        BCF (Bayesian Causal Forests) is a regression method for observational
        causal inference studies introduced in [1]_ based on a pair of BART
        models.

        This class evaluates a Gaussian process regression with a kernel which
        accurately approximates BCF in the infinite trees limit of each BART
        model. The hyperparameters are optimized to their marginal MAP.

        The model is (loosely, see notes below) :math:`y = \mu(x) + z\tau(x)`,
        so :math:`\tau(x)` is the expected causal effect of :math:`z` on
        :math:`y` at location :math:`x`.

        Parameters
        ----------
        y : (n,) array, series or dataframe
            Outcome.
        z : (n,) array, series or dataframe
            Binary treatment status: 0 control group, 1 treatment group.
        x_mu : (n, p) array, series or dataframe
            Covariates for the :math:`\mu` model.
        x_tau : (n, q) array, series or dataframe, optional
            Covariates for the :math:`\tau` model. If not specified, use
            `x_mu`.
        pihat : (n,) array, series or dataframe
            Estimated propensity score, i.e., P(Z=1|X).
        include_pi : {'mu', 'tau', 'both'}, optional
            Whether to include the propensity score in the :math:`\mu` model,
            the :math:`\tau` model, or both. Default is ``'mu'``.
        weights : (n,) array, series or dataframe
            Weights used to rescale the error variance (as 1 / weight).
        fitkw : dict
            Additional arguments passed to `~lsqfitgp.empbayes_fit`,
            overriding the defaults.
        kernelkw_mu, kernelkw_tau : dict
            Additional arguments passed to `~lsqfitgp.BART` for each model,
            overriding the defaults.
        marginalize_mean : bool
            If True (default), marginalize the intercept of the model.
        gpaux : callable, optional
            If specified, this function is called with a pair ``(hp, gp)``,
            where ``hp`` is a dictionary of hyperparameters and ``gp`` is a
            `~lsqfitgp.GP` object under construction; it is expected to return
            a modified ``gp`` with a new process named ``'aux'`` defined with
            `~lsqfitgp.GP.defproc` or similar. The process is added to the
            regression model. The input to the process is a structured array
            with fields:

            'train' : bool
                Indicates whether the point belongs to the training set (the
                one passed on initialization) or the test set (the one passed
                to `pred` or `gp`).
            'i' : int
                Index of the flattened array.
            'z' : int
                Treatment status.
            'mu', 'tau' : structured
                The values in `x_mu` and `x_tau`, converted to indices
                according to the BART grids. Where `pihat` has been added,
                there are two subfields: ``'x'``, which contains the
                covariates, and ``'pihat'``, the latter expressed in indices
                as well.
            'pihat' : float
                The `pihat` argument. Contrary to the subfield included under
                ``'mu'`` and/or ``'tau'``, this field contains the original
                values.
            'aux' : structured
                The values in `x_aux`, if specified.

        x_aux : (n, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process.
        otherhp : dictionary of gvar
            A dictionary with the priors of arbitrary additional
            hyperparameters, intended to be used by ``gpaux`` or ``transf``.
        transf : (list of) str or pair of callable
            Data transformation. Either a string indicating a pre-defined
            transformation, or a pair ``(from_data, to_data)``, two functions
            with signatures ``from_data(hp, y) -> eta`` and ``to_data(hp, eta)
            -> y``, where ``eta`` is the value to which the model is fit, and
            ``hp`` is the dictionary of hyperparameters. The functions must be
            ufuncs, each the inverse of the other w.r.t. the second parameter.
            ``from_data`` must be differentiable with `jax` w.r.t. ``y``.

            If a list of such specifications is provided, the transformations
            are applied in order, with the first one being the outermost,
            i.e., the one applied first to the data.

            If a transformation uses additional hyperparameters, either
            predefined automatically or passed by the user through `otherhp`,
            they are inferred with the rest of the hyperparameters.

            The pre-defined transformations are:

            'standardize' (default)
                eta = (y - mean(train_y)) / sdev(train_y)
            'yeojohnson'
                The Yeo-Johnson transformation [2]_ to reduce skewness. The
                :math:`\lambda` parameter is bounded in :math:`(0, 2)` for
                implementation convenience; this restriction may be lifted in
                future versions.

        Notes
        -----
        The regression model is:

        .. math::
            \eta_i = g(y_i; \ldots) &= m + {} \\
                &\phantom{{}={}} +
                \lambda_\mu
                \mu(\mathbf x^\mu_i, \hat\pi_i?) + {} \\
                &\phantom{{}={}} +
                \lambda_\tau
                \tau(\mathbf x^\tau_i, \hat\pi_i?) (z_i - z_0) + {} \\
                &\phantom{{}={}} +
                \mathrm{aux}(i, z_i, \mathbf x^\mu_i, \mathbf x^\tau_i,
                \hat\pi_i, \mathbf x^\text{aux}_i) + {} \\
                &\phantom{{}={}} +
                \varepsilon_i, \\
            \varepsilon_i &\sim
                N(0, \sigma^2 / w_i), \\
            m &\sim N(0, 1), \\
            \log \sigma^2 &\sim N(\log\bar w, 4), \\
            \lambda_\mu
                &\sim \mathrm{HalfCauchy}(2), \\
            \lambda_\tau
                &\sim \mathrm{HalfNormal}(1.48), \\
            \mu &\sim \mathrm{GP}(0,
                \mathrm{BART}(\alpha_\mu, \beta_\mu) ), \\
            \tau &\sim \mathrm{GP}(0,
                \mathrm{BART}(\alpha_\tau, \beta_\tau) ), \\
            \mathrm{aux} &\sim \mathrm{GP}(0, \text{<user defined>}), \\
            \alpha_\mu, \alpha_\tau &\sim \mathrm{Beta}(2, 1), \\
            \beta_\mu, \beta_\tau &\sim \mathrm{InvGamma}(1, 1), \\
            z_0 &\sim U(0, 1).

        To perform the inference, :math:`(\mu, \tau, \boldsymbol\varepsilon,
        m, \mathrm{aux})` are marginalized analytically, and the marginal
        posterior mode of
        :math:`(\sigma, \lambda_*, \alpha_*, \beta_*, z_0, \ldots)` is found
        by numerical minimization, after transforming them to express their
        prior as a Gaussian copula. Their marginal posterior covariance matrix
        is estimated with an approximation of the inverse Hessian. See
        `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize
        this procedure.

        The tree splitting grid of the BART kernel is set using quantiles of
        the observed covariates. This corresponds to settings
        ``usequants=True``, ``numcut=inf`` in the R packages BayesTree and
        BART. Use the parameters `kernelkw_mu` and `kernelkw_tau` to customize
        the grids.

        The difference between the regression model evaluated at :math:`Z=1`
        vs. :math:`Z=0` can be interpreted as the causal effect :math:`Z
        \rightarrow Y` if the unconfoundedness assumption is made:

        .. math::
            \{Y(Z=0), Y(Z=1)\} \perp\!\!\!\perp Z \mid X.

        In practical terms, this holds when:

        1) :math:`X` are pre-treatment variables, i.e., they represent
           quantities causally upstream of :math:`Z`.

        2) :math:`X` are sufficient to adjust for all common causes of
           :math:`Z` and :math:`Y`, such that the only remaining difference
           is the causal effect and not just a correlation.

        Here :math:`X` consists of `x_tau`, `x_mu` and `x_aux`. However, these
        arrays may also be used to pass "technical" values used to set up the
        model that do not satisfy the unconfoundedness assumption, if you know
        what you are doing.

        Attributes
        ----------
        m : float or gvar
            The prior mean :math:`m`.
        sigma : gvar
            The error term standard deviation :math:`\sigma`. If there are
            weights, the sdev for each unit is obtained by dividing ``sigma``
            by sqrt(weight).
        alpha_mu, alpha_tau : gvar
            The numerator of the tree spawn probability :math:`\alpha_*`
            (named ``base`` in R bcf).
        beta_mu, beta_tau : gvar
            The depth exponent of the tree spawn probability :math:`\beta_*`
            (named ``power`` in R bcf).
        lambda_mu, lambda_tau : gvar
            The prior standard deviation :math:`\lambda_*`.
        z_0 : gvar
            The treatment coding parameter.
        fit : empbayes_fit
            The hyperparameters fit object.

        Methods
        -------
        gp :
            Create a GP object.
        data :
            Create the dictionary to be passed to `GP.pred` to represent data.
        pred :
            Evaluate the regression function at given locations.
        from_data :
            Convert :math:`y` to :math:`\eta`.
        to_data :
            Convert :math:`\eta` to :math:`y`.

        See also
        --------
        lsqfitgp.BART

        References
        ----------
        .. [1] P. Richard Hahn, Jared S. Murray, Carlos M. Carvalho, "Bayesian
            Regression Tree Models for Causal Inference: Regularization,
            Confounding, and Heterogeneous Effects (with Discussion)," Bayesian
            Analysis 15(3), 965-1056, September 2020,
            https://doi.org/10.1214/19-BA1195
        .. [2] Yeo, In-Kwon; Johnson, Richard A. (2000). "A New Family of
            Power Transformations to Improve Normality or Symmetry".
            Biometrika 87(4): 954–959. https://doi.org/10.1093/biomet/87.4.954
        """

        # convert covariates to StructuredArray
        x_mu = self._to_structured(x_mu)
        if x_tau is not None:
            x_tau = self._to_structured(x_tau)
            assert x_tau.shape == x_mu.shape
        if x_aux is not None:
            x_aux = self._to_structured(x_aux)
            assert x_aux.shape == x_mu.shape

        # convert outcomes, treatment, propensity score, weights to 1d arrays
        y = self._to_vector(y)
        z = self._to_vector(z)
        pihat = self._to_vector(pihat)
        assert y.shape == z.shape == pihat.shape == x_mu.shape
        if weights is not None:
            weights = self._to_vector(weights)
            assert weights.shape == x_mu.shape

        # check include_pi
        if include_pi not in ('mu', 'tau', 'both'):
            raise KeyError(f'invalid value include_pi={include_pi!r}')
        self._include_pi = include_pi

        # add pihat to covariates
        x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat)

        # grid and indices
        splits_mu = _kernels.BART.splits_from_coord(x_mu)
        i_mu = self._toindices(x_mu, splits_mu)
        if x_tau is None:
            splits_tau = splits_mu
            i_tau = None
        else:
            splits_tau = _kernels.BART.splits_from_coord(x_tau)
            i_tau = self._toindices(x_tau, splits_tau)

        # get functions for data transformation
        from_data, to_data, transfloss, transfhp = self._get_transf(
            transf=transf, weights=weights, y=y)

        # scale of error variance
        logsigma2_loc = 0 if weights is None else numpy.log(jnp.mean(weights))

        # prior on hyperparams
        hyperprior = copula.makedict({
            'm': gvar.gvar(0, 1),
            'sigma^2': copula.lognorm(logsigma2_loc, 2),
            'lambda_mu': copula.halfcauchy(2),
            'lambda_tau': copula.halfnorm(1.48),
            'alpha_mu': copula.beta(2, 1),
            'alpha_tau': copula.beta(2, 1),
            'beta_mu': copula.invgamma(1, 1),
            'beta_tau': copula.invgamma(1, 1),
            'z_0': copula.uniform(0, 1),
        })

        # remove explicit mean parameter if it's baked into the Gaussian process
        if marginalize_mean:
            hyperprior.pop('m')

        # add data transformation and user hyperparameters
        def update_hyperparams(new, newname, raises):
            new = gvar.BufferDict(new)
            for key in new.all_keys():
                if hyperprior.has_dictkey(key):
                    message = f'{newname} hyperparameter {key!r} overrides existing one'
                    if raises:
                        raise ValueError(message)
                    else:
                        warnings.warn(message)
            hyperprior.update(new)
        update_hyperparams(transfhp, 'data transformation', True)
            # the hypers handed by _get_transf are not allowed to override
        update_hyperparams(otherhp, 'user', False)

        # GP factory
        def gpfactory(hp, *, z, i_mu, i_tau, pihat, x_aux, weights,
                splits_mu, splits_tau, **_):

            # TODO maybe I should pass kernelkw_* as arguments, but they may
            # not be jittable. I need jitkw in empbayes_fit for that.

            kw_overridable = dict(
                maxd=10,
                reset=[2, 4, 6, 8],
                intercept=False,
            )
            kw_not_overridable = dict(indices=True)

            gp = _GP.GP(checkpos=False, checksym=False, solver='chol')

            for name, kernelkw in dict(mu=kernelkw_mu, tau=kernelkw_tau).items():
                # pick the splitting grid matching this model
                splits = splits_mu if name == 'mu' else splits_tau
                kw = dict(
                    alpha=hp[f'alpha_{name}'],
                    beta=hp[f'beta_{name}'],
                    dim=name,
                    splits=splits,
                    **kw_overridable,
                )
                kw.update(kernelkw)
                kernel = _kernels.BART(**kw, **kw_not_overridable)
                kernel *= hp[f'lambda_{name}'] ** 2

                gp = gp.defproc(name, kernel)

            if 'm' in hp:
                kernel_mean = 0 * _kernels.Constant()
            else:
                kernel_mean = _kernels.Constant()
            gp = gp.defproc('m', kernel_mean)

            if gpaux is None:
                gp = gp.defproc('aux', 0 * _kernels.Constant())
            else:
                gp = gpaux(hp, gp)

            gp = gp.deflintransf(
                gp.DefaultProcess,
                lambda m, mu, tau, aux: lambda x:
                    m(x) + mu(x) + tau(x) * (x['z'] - hp['z_0']) + aux(x),
                ['m', 'mu', 'tau', 'aux'],
            )

            x = self._join_points(True, z, i_mu, i_tau, pihat, x_aux)
            gp = gp.addx(x, 'trainmean')
            errcov = self._error_cov(hp, weights, x)
            return (gp
                .addcov(errcov, 'trainnoise')
                .addtransf({'trainmean': 1, 'trainnoise': 1}, 'train')
            )

        # data factory
        def data(hp, *, y, **_):
            return {'train': from_data(hp, y) - hp.get('m', 0)}

        # fit hyperparameters
        gpkw = dict(
            y=y,
            z=z,
            i_mu=i_mu,
            i_tau=i_tau,
            pihat=pihat,
            x_aux=x_aux,
            weights=weights,
            splits_mu=splits_mu,
            splits_tau=splits_tau,
        )
        options = dict(
            verbosity=3,
            minkw=dict(method='l-bfgs-b', options=dict(maxls=4, maxiter=100)),
            mlkw=dict(epsrel=0),
            forward=True,
            gpfactorykw=gpkw,
            additional_loss=transfloss,
        )
        options.update(fitkw)
        fit = _fit.empbayes_fit(hyperprior, gpfactory, data, **options)

        # extract hyperparameters from minimization result
        self.m = fit.p.get('m', 0)
        self.sigma = gvar.sqrt(fit.p['sigma^2'])
        self.lambda_mu = fit.p['lambda_mu']
        self.lambda_tau = fit.p['lambda_tau']
        self.alpha_mu = fit.p['alpha_mu']
        self.alpha_tau = fit.p['alpha_tau']
        self.beta_mu = fit.p['beta_mu']
        self.beta_tau = fit.p['beta_tau']
        self.z_0 = fit.p['z_0']

        # save other attributes
        self.fit = fit
        self._from_data = from_data
        self._to_data = to_data
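
    # A minimal construction sketch (illustrative, synthetic data; in a real
    # analysis `pihat` would come from a separately fitted propensity model):
    #
    #     >>> rng = numpy.random.default_rng(0)
    #     >>> n = 100
    #     >>> x = rng.standard_normal((n, 3))
    #     >>> z = rng.binomial(1, 0.5, n)
    #     >>> pihat = numpy.full(n, 0.5)
    #     >>> y = x @ [1., 2., 3.] + 0.5 * z + rng.standard_normal(n)
    #     >>> model = bcf(y=y, z=z, x_mu=x, pihat=pihat)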

    def _append_pihat(self, x_mu, x_tau, pihat):
        ip = self._include_pi
        if ip == 'mu' or ip == 'both':
            x_mu = _array.StructuredArray.from_dict(dict(
                x=x_mu,
                pihat=pihat,
            ))
        if x_tau is not None and (ip == 'tau' or ip == 'both'):
            x_tau = _array.StructuredArray.from_dict(dict(
                x=x_tau,
                pihat=pihat,
            ))
        return x_mu, x_tau

    @staticmethod
    def _join_points(train, z, i_mu, i_tau, pihat, x_aux):
        """ join covariates into a single StructuredArray """
        columns = dict(
            train=jnp.broadcast_to(bool(train), z.shape),
            i=jnp.arange(z.size).reshape(z.shape),
            z=z,
            mu=i_mu,
            tau=i_mu if i_tau is None else i_tau,
            pihat=pihat,
        )
        if x_aux is not None:
            columns.update(aux=x_aux)
        return _array.StructuredArray.from_dict(columns)

    @staticmethod
    def _error_cov(hp, weights, x):
        """ fill error covariance matrix """
        if weights is None:
            error_var = jnp.broadcast_to(hp['sigma^2'], len(x))
        else:
            error_var = hp['sigma^2'] / weights
        return jnp.diag(error_var)

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', z=None, x_mu=None, x_tau=None, pihat=None,
        x_aux=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        z : (m,) array, series or dataframe, optional
            Treatment status at test points. If specified, `x_mu` and `pihat`
            must also be specified, together with `x_tau` and `x_aux` if and
            only if they were specified at initialization.
        x_mu : (m, p) array, series or dataframe, optional
            Control model covariates at test points.
        x_tau : (m, q) array, series or dataframe, optional
            Moderating model covariates at test points.
        pihat : (m,) array, series or dataframe, optional
            Estimated propensity score at test points.
        x_aux : (m, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process.
        weights : (m,) array, series or dataframe, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the `m`
            attribute of the `bcf` object. The keys of the GP are ``'@mean'``,
            ``'@noise'``, and ``'@'``, where the "@" stands either for 'train'
            or 'test', and @ = @mean + @noise.

            This Gaussian process is defined on the transformed data ``eta``.
        """

        hp = self._gethp(hp, rng)
        return self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw)
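
    # A sketch of combining `gp` and `data` for manual predictions at new
    # points (hypothetical arrays `z_new`, `x_new`, `pihat_new`; exact return
    # types follow `GP.predfromdata`):
    #
    #     >>> g = model.gp(z=z_new, x_mu=x_new, pihat=pihat_new)
    #     >>> eta = model.data()
    #     >>> post = g.predfromdata(eta, 'testmean')  # posterior on eta scale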

    def _gp(self, hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw):
        """
        Internal function to create the GP object. This function must work
        both when the arguments are user-provided, and need to be checked and
        converted to standard format, and when they are traced jax values.
        """

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if z is not None:

            # check presence/absence of arguments is coherent
            self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux)

            # check treatment and propensity score
            z = self._to_vector(z)
            pihat = self._to_vector(pihat)
            assert pihat.shape == z.shape

            # check weights
            if weights is not None:
                weights = self._to_vector(weights)
                assert weights.shape == z.shape

            # add propensity score to covariates
            x_mu = self._to_structured(x_mu)
            assert x_mu.shape == z.shape
            if x_tau is not None:
                x_tau = self._to_structured(x_tau)
                assert x_tau.shape == z.shape
            x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat)

            # convert covariates to indices
            i_mu = self._toindices(x_mu, gpfactorykw['splits_mu'])
            assert i_mu.dtype == gpfactorykw['i_mu'].dtype
            if x_tau is not None:
                i_tau = self._toindices(x_tau, gpfactorykw['splits_tau'])
                assert i_tau.dtype == gpfactorykw['i_tau'].dtype
            else:
                i_tau = None

            # check auxiliary points
            if x_aux is not None:
                x_aux = self._to_structured(x_aux)

            # add test points
            x = self._join_points(False, z, i_mu, i_tau, pihat, x_aux)
            gp = gp.addx(x, 'testmean')
            errcov = self._error_cov(hp, weights, x)
            gp = (gp
                .addcov(errcov, 'testnoise')
                .addtransf({'testmean': 1, 'testnoise': 1}, 'test')
            )

        return gp

    def _check_coherent_covariates(self, z, x_mu, x_tau, pihat, x_aux):
        if z is None:
            assert x_mu is None
            assert x_tau is None
            assert pihat is None
            assert x_aux is None
        else:
            assert x_mu is not None
            assert pihat is not None
            train_tau = self.fit.gpfactorykw['i_tau']
            if x_tau is None:
                assert train_tau is None
            else:
                assert train_tau is not None
            train_aux = self.fit.gpfactorykw['x_aux']
            if x_aux is None:
                assert train_aux is None
            else:
                assert train_aux is not None

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``eta`` in the format required by the
            `GP.pred` method.
        """

        hp = self._gethp(hp, rng)
        return self.fit.data(hp, **self.fit.gpfactorykw)

    def pred(self, *, hp='map', error=False, z=None, x_mu=None, x_tau=None,
        pihat=None, x_aux=None, weights=None, transformed=True, samples=None,
        gvars=False, rng=None):
        r"""
        Predict the transformed outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        error : bool, default False
            If ``False``, make a prediction for the latent mean. If ``True``,
            add the error term.
        z : (m,) array, series or dataframe, optional
            Treatment status at test points. If specified, `x_mu` and `pihat`
            must also be specified, together with `x_tau` and `x_aux` if and
            only if they were specified at initialization.
        x_mu : (m, p) array, series or dataframe, optional
            :math:`\mu` model covariates at test points.
        x_tau : (m, q) array, series or dataframe, optional
            :math:`\tau` model covariates at test points.
        pihat : (m,) array, series or dataframe, optional
            Estimated propensity score at test points.
        x_aux : (m, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process at test points.
        weights : (m,) array, series or dataframe, optional
            Weights for the error variance on the test points.
        transformed : bool, default True
            If ``True``, return the prediction on the transformed outcome
            :math:`\eta`, else on the observable outcome :math:`y`.
        samples : int, optional
            If specified, the number of samples to take from the posterior. If
            not, return the mean and covariance matrix of the posterior.
        gvars : bool, default False
            If ``True``, return the mean and covariance matrix of the
            posterior as an array of `GVar` variables.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'`` or ``samples``
            is not `None`.

        Returns
        -------
        If ``samples`` is `None` and ``gvars`` is `False` (default):

        mean, cov : (m,) and (m, m) arrays
            The mean and covariance matrix of the Normal posterior
            distribution over the regression function or :math:`\eta` at the
            specified locations.

        If ``samples`` is `None` and ``gvars`` is `True`:

        out : (m,) array of gvars
            The same distribution represented as an array of `~gvar.GVar`
            objects.

        If ``samples`` is an integer:

        sample : (samples, m) array
            Posterior samples over either the regression function,
            :math:`\eta`, or :math:`y`.
        """

        # check consistency of output choice
        if samples is None:
            if not transformed:
                raise ValueError('Posterior is required in analytical form '
                    '(samples=None) and in data space (transformed=False), '
                    'this is not possible as the transformation model space '
                    '-> data space is arbitrary. Either sample the posterior '
                    'or get the result in model space.')
        else:
            if not transformed and not error:
                raise ValueError('Posterior is required in data space '
                    '(transformed=False) and without error term '
                    '(error=False), this is not possible as the '
                    'transformation model space -> data space applies after '
                    'adding the error.')
            assert not gvars, 'can not represent posterior samples as gvars'

        # TODO allow exceptions to these rules when there are no
        # transformations or the only transformation is 'standardize'.

        # get hyperparameters
        hp = self._gethp(hp, rng)

        # check presence of covariates is coherent
        self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux)

        # convert all inputs to arrays compatible with jax to pass them to the
        # compiled implementation
        if z is not None:
            z = self._to_vector(z)
            pihat = self._to_vector(pihat)
            x_mu = self._to_structured(x_mu)
            if x_tau is not None:
                x_tau = self._to_structured(x_tau)
            if x_aux is not None:
                x_aux = self._to_structured(x_aux)
        if weights is not None:
            weights = self._to_vector(weights)

        # GP regression
        mean, cov = self._pred(hp, z, x_mu, x_tau, pihat, x_aux, weights,
            self.fit.gpfactorykw, bool(error))

        # return Normal posterior moments
        if samples is None:
            if gvars:
                return gvar.gvar(mean, cov, fast=True)
            else:
                return mean, cov

        # sample from posterior
        sample = jnp.stack(list(_fastraniter.raniter(mean, cov, n=samples, rng=rng)))
            # TODO when I add vectorized sampling, use it here
        if not transformed:
            sample = self._to_data(hp, sample)
        return sample

        # TODO the default should be something in data space, so with samples.
        # If I handle the analytical posterior through standardize, I could
        # also make it without samples by default. Although I guess for
        # whatever calculations samples are more convenient (just do the
        # calculation on the samples.)
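
    # A sketch of typical `pred` usage (continuing the synthetic-data example
    # above): draw posterior samples of the observable outcome, error term
    # included, at the training points.
    #
    #     >>> rng = numpy.random.default_rng(0)
    #     >>> ysamp = model.pred(error=True, transformed=False, samples=1000,
    #     ...                    rng=rng)
    #     >>> ysamp.shape     # (samples, n)
    #     (1000, 100)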

    @functools.cached_property
    def _pred(self):

        @functools.partial(jax.jit, static_argnums=(8,))
        def _pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw, error):
            gp = self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw)
            data = self.fit.data(hp, **gpfactorykw)
            if z is None:
                label = 'train'
            else:
                label = 'test'
            if not error:
                label += 'mean'
            outmean, outcov = gp.predfromdata(data, label, raw=True)
            return outmean + hp.get('m', 0), outcov

            # TODO make everything pure and jit this per class instead of per
            # instance

        return _pred

    def from_data(self, y, *, hp='map', rng=None):
        """
        Transform outcomes :math:`y` to the regression variable :math:`\\eta`.

        Parameters
        ----------
        y : (n,) array
            Outcomes.
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        eta : (n,) array
            Transformed outcomes.
        """

        hp = self._gethp(hp, rng)
        return self._from_data(hp, y)

    def to_data(self, eta, *, hp='map', rng=None):
        """
        Convert the regression variable :math:`\\eta` to outcomes :math:`y`.

        Parameters
        ----------
        eta : (n,) array
            Transformed outcomes.
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        y : (n,) array
            Outcomes.
        """

        hp = self._gethp(hp, rng)
        return self._to_data(hp, eta)
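
    # Round-trip sketch: at fixed hyperparameters `to_data` inverts
    # `from_data`, up to floating point error.
    #
    #     >>> eta = model.from_data(y)
    #     >>> numpy.allclose(model.to_data(eta), y)
    #     True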

    @classmethod
    def _to_structured(cls, x, *, check_numerical=True):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif hasattr(x, 'to_numpy'):
            x = _array.StructuredArray.from_dict({
                'f0' if x.name is None else x.name: x.to_numpy()
            })
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check fields are numerical, for BART
        if check_numerical:
            assert x.ndim == 1
            assert x.size > len(x.dtype)
            def check_numerical(path, dtype):
                if not numpy.issubdtype(dtype, numpy.number):
                    raise TypeError(f'covariate `{path}` is not numerical')
            cls._walk_dtype(x.dtype, check_numerical)

        return x

    @staticmethod
    def _to_vector(x):
        if hasattr(x, 'columns'): # dataframe
            x = x.to_numpy().squeeze(axis=1)
        elif hasattr(x, 'to_numpy'): # series (dataframe column)
            x = x.to_numpy()
        x = jnp.asarray(x)
        if x.ndim != 1:
            raise ValueError(f'array is not 1d vector, ndim={x.ndim}')
        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        dtype = cast(x.dtype, ix.dtype)
        return _array.unstructured_to_structured(ix, dtype=dtype)

    def __repr__(self):

        with _gvarext.gvar_format():
            if hasattr(self.m, 'sdev'):
                m = str(self.m)
            else:
                m = f'{self.m:.3g}'

            n = self.fit.gpfactorykw['y'].size
            p_mu = _array._nd(self.fit.gpfactorykw['i_mu']['x'].dtype)
            x_tau = self.fit.gpfactorykw['i_tau']
            x_aux = self.fit.gpfactorykw['x_aux']

            out = f"""\
Data:
    n = {n}"""

            if x_tau is None:
                out += f"""
    p = {p_mu}"""
            else:
                p_tau = _array._nd(x_tau['x'].dtype)
                out += f"""
    p_mu/tau = {p_mu}, {p_tau}"""

            if x_aux is not None:
                p_aux = _array._nd(x_aux['x'].dtype)
                out += f"""
    p_aux = {p_aux}"""

            out += f"""
Hyperparameter posterior:
    m = {m}
    z_0 = {self.z_0}
    alpha_mu/tau = {self.alpha_mu} {self.alpha_tau}
    beta_mu/tau = {self.beta_mu} {self.beta_tau}
    lambda_mu/tau = {self.lambda_mu} {self.lambda_tau}"""

            weights = self.fit.gpfactorykw['weights']
            if weights is None:
                out += f"""
    sigma = {self.sigma}"""
            else:
                weights = numpy.array(weights) # to avoid jax taking over the ops
                avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
                out += f"""
    sqrt(mean(sigma^2/w)) = {avgsigma}
    sigma = {self.sigma}"""

            out += """
Meaning of hyperparameters:
    mu(x) = reference outcome level
    tau(x) = effect of the treatment
    z_0 in (0, 1): reference treatment level
        z_0 -> 0: mu is the model of the untreated
        z_0 -> 1: mu is the model of the treated
    alpha in (0, 1)
        alpha -> 0: constant function
        alpha -> 1: no constraints on the function
    beta in (0, ∞)
        beta -> 0: no constraints on the function
        beta -> ∞: no interactions, f(x) = f1(x1) + f2(x2) + ...
    lambda in (0, ∞): standard deviation of function
        lambda small: confident extrapolation
        lambda large: conservative extrapolation
    sigma in (0, ∞): standard deviation of i.i.d. error"""

        return _utils.top_bottom_rule('BCF', out)

        # TODO print user parameters, applying transformations. Copy the dict
        # and use .pop() to remove the predefined params as they are printed.

    def _get_transf(self, *, transf, y, weights):

        from_datas = []
        to_datas = []
        hypers = {}

        if transf is None:
            transf = []
        elif isinstance(transf, list):
            name = lambda n: f'transf{i}_{n}'
        else:
            name = lambda n: n
            transf = [transf]

        for i, tr in enumerate(transf):

            hyper = {}

            if not isinstance(tr, str):

                from_data, to_data = tr

            elif tr == 'standardize':

                if i > 0:
                    warnings.warn('standardization applied after other '
                        'transformations: standardization always uses the '
                        'initial data mean and standard deviation, so it may '
                        'not work as intended')

                    # It's not possible to overcome this limitation if one
                    # wants to stick to transformations that act on one point
                    # at a time to make them generalizable out of sample.

                if weights is None:
                    loc = jnp.mean(y)
                    scale = jnp.std(y)
                else:
                    loc = jnp.average(y, weights=weights)
                    scale = jnp.sqrt(jnp.average((y - loc) ** 2, weights=weights))

                def from_data(hp, y):
                    return (y - loc) / scale
                def to_data(hp, eta):
                    return loc + scale * eta

            elif tr == 'yeojohnson':

                def from_data(hp, y):
                    return yeojohnson(y, hp[name('lambda_yj')])
                def to_data(hp, eta):
                    return yeojohnson_inverse(eta, hp[name('lambda_yj')])
                hyper[name('lambda_yj')] = 2 * copula.beta(2, 2)

            else:
                raise KeyError(tr)

            from_datas.append(from_data)
            to_datas.append(to_data)
            hypers.update(hyper)

        if transf:
            def from_data(hp, y):
                for fd in from_datas:
                    y = fd(hp, y)
                return y
            def to_data(hp, eta):
                for td in reversed(to_datas):
                    eta = td(hp, eta)
                return eta
        else:
            from_data = lambda hp, y: y
            to_data = lambda hp, eta: eta

        from_data_grad = _jaxext.elementwise_grad(from_data, 1)
        def loss(hp):
            return -jnp.sum(jnp.log(from_data_grad(hp, y)))

        hypers = copula.makedict(hypers)

        return from_data, to_data, loss, hypers
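
# A sketch of a custom `transf` pair for the constructor, following the
# documented signatures `from_data(hp, y) -> eta` and `to_data(hp, eta) -> y`
# (assumes strictly positive outcomes, since it uses a log transform):
#
#     >>> log_transf = (lambda hp, y: jnp.log(y),
#     ...               lambda hp, eta: jnp.exp(eta))
#     >>> model = bcf(y=y, z=z, x_mu=x, pihat=pihat, transf=log_transf)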

def yeojohnson(x, lmbda):
    """ Yeo-Johnson transformation with lambda != 0, 2 """
    return jnp.where(
        x >= 0,
        (jnp.power(x + 1, lmbda) - 1) / lmbda,
        -((jnp.power(-x + 1, 2 - lmbda) - 1) / (2 - lmbda))
    )

    # TODO
    # - rewrite the cases with expm1, log1p, etc. to make them accurate
    # - split the cases into lambda 0/2
    # - make custom_jvps for the singular points to define derivatives w.r.t.
    #   lambda even though it does not appear in the expression
    # - add unit tests that check gradients with finite differences

def yeojohnson_inverse(y, lmbda):
    return jnp.where(
        y >= 0,
        jnp.power(y * lmbda + 1, 1 / lmbda) - 1,
        -jnp.power(-(2 - lmbda) * y + 1, 1 / (2 - lmbda)) + 1
    )
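
# A quick inversion sanity check for the pair above (illustrative; away from
# the singular values lambda = 0, 2):
#
#     >>> x = jnp.linspace(-3, 3, 7)
#     >>> jnp.allclose(yeojohnson_inverse(yeojohnson(x, 1.3), 1.3), x)
#     Array(True, dtype=bool)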