Coverage for src/lsqfitgp/bayestree/_bcf.py: 80%

359 statements  

coverage.py v7.9.2, created at 2025-07-17 13:39 +0000

# lsqfitgp/bayestree/_bcf.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools
import warnings

import numpy
from scipy import stats
from jax import numpy as jnp
import jax
import gvar

from .. import copula
from .. import _kernels
from .. import _fit
from .. import _array
from .. import _GP
from .. import _fastraniter
from .. import _jaxext
from .. import _gvarext
from .. import _utils

# TODO add methods or options to do causal inference stuff, e.g., impute missing
# outcomes, or ate, att, cate, catt, sate, satt. Remember that the effect may
# also depend on aux. See bartCause, possibly copy its naming.

def _recursive_cast(dtype, default, mapping):
    if dtype in mapping:
        return mapping[dtype]
    elif dtype.names is not None:
        return numpy.dtype([
            (name, _recursive_cast(dtype[name], default, mapping))
            for name in dtype.names
        ])
    elif dtype.subdtype is not None:
        # note: has names => does not have subdtype
        return numpy.dtype((_recursive_cast(dtype.base, default, mapping), dtype.shape))
    elif default is None:
        return dtype
    else:
        return default

def cast(dtype, default, mapping={}):
    """
    Recursively cast a numpy data type.

    Parameters
    ----------
    dtype : dtype
        The data type to cast.
    default : dtype or None
        The leaf fields of `dtype` are cast to `default`, which can be
        structured, unless they appear in `mapping`. If None, dtypes not in
        `mapping` are left unchanged.
    mapping : dict
        A dictionary from dtypes to dtypes, indicating specific casting rules.
        The dtypes can be structured; a match on a structured dtype takes
        precedence over matches in its leaves, and the converted dtype is not
        searched further for matches.

    Returns
    -------
    casted_dtype : dtype
        The cast version of `dtype`. May not have the same structure if
        `mapping` contains structured dtypes.
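
    Examples
    --------
    A minimal sketch (the dtypes here are made up for illustration):

    >>> dt = numpy.dtype([('a', 'f8'), ('b', [('c', 'i8')])])
    >>> cast(dt, 'f4', {'i8': 'f8'})
    dtype([('a', '<f4'), ('b', [('c', '<f8')])])

    Leaves not listed in the mapping (here ``'a'``) are cast to `default`,
    while mapped leaves (here ``'c'``) follow the mapping.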

82 """ 

83 mapping = {numpy.dtype(k): numpy.dtype(v) for k, v in mapping.items()} 1abcde

84 default = None if default is None else numpy.dtype(default) 1abcde

85 return _recursive_cast(numpy.dtype(dtype), default, mapping) 1abcde

86 

87 # TODO 

88 # - move this to generic utils 

89 # - make unit tests 

90 


class bcf:

    def __init__(self, *,
        y,
        z,
        x_mu,
        x_tau=None,
        pihat,
        include_pi='mu',
        weights=None,
        fitkw={},
        kernelkw_mu={},
        kernelkw_tau={},
        marginalize_mean=True,
        gpaux=None,
        x_aux=None,
        otherhp={},
        transf='standardize',
    ):
110 r""" 

111 Nonparametric Bayesian regression with a GP version of BCF. 

112 

113 BCF (Bayesian Causal Forests) is a regression method for observational 

114 causal inference studies introduced in [1]_ based on a pair of BART 

115 models. 

116 

117 This class evaluates a Gaussian process regression with a kernel which 

118 accurately approximates BCF in the infinite trees limit of each BART 

119 model. The hyperparameters are optimized to their marginal MAP. 

120 

121 The model is (loosely, see notes below) :math:`y = \mu(x) + z\tau(x)`, 

122 so :math:`\tau(x)` is the expected causal effect of :math:`z` on 

123 :math:`y` at location :math:`x`. 

124 

125 Parameters 

126 ---------- 

127 y : (n,) array, series or dataframe 

128 Outcome. 

129 z : (n,) array, series or dataframe 

130 Binary treatment status: 0 control group, 1 treatment group. 

131 x_mu : (n, p) array, series or dataframe 

132 Covariates for the :math:`\mu` model. 

133 x_tau : (n, q) array, series or dataframe, optional 

134 Covariates for the :math:`\tau` model. If not specified, use `x_mu`. 

135 pihat : (n,) array, series or dataframe 

136 Estimated propensity score, i.e., P(Z=1|X). 

137 include_pi : {'mu', 'tau', 'both'}, optional 

138 Whether to include the propensity score in the :math:`\mu` model, 

139 the :math:`\tau` model, or both. Default is ``'mu'``. 

140 weights : (n,) array, series or dataframe 

141 Weights used to rescale the error variance (as 1 / weight). 

142 fitkw : dict 

143 Additional arguments passed to `~lsqfitgp.empbayes_fit`, overrides 

144 the defaults. 

145 kernelkw_mu, kernelkw_tau : dict 

146 Additional arguments passed to `~lsqfitgp.BART` for each model, 

147 overrides the defaults. 

148 marginalize_mean : bool 

149 If True (default), marginalize the intercept of the model. 

150 gpaux : callable, optional 

151 If specified, this function is called with a pair ``(hp, gp)``, 

152 where ``hp`` is a dictionary of hyperparameters, and ``gp`` is a 

153 `~lsqfitgp.GP` object under construction, and is expected to return 

154 a modified ``gp`` with a new process named ``'aux'`` defined with 

155 `~lsqfitgp.GP.defproc` or similar. The process is added to the 

156 regression model. The input to the process is a structured array 

157 with fields: 

158 

159 'train' : bool 

160 Indicates whether the data is training set (the one passed on 

161 initialization) or test set (the one passed to `pred` or `gp`). 

162 'i' : int 

163 Index of the flattened array. 

164 'z' : int 

165 Treatment status. 

166 'mu', 'tau' : structured 

167 The values in `x_mu` and `x_tau`, converted to indices according 

168 to the BART grids. Where `pihat` has been added, there are two 

169 subfields: ``'x'`` which contains the covariates, and 

170 ``'pihat'``, the latter expressed in indices as well. 

171 'pihat' : float 

172 The `pihat` argument. Contrary to the subfield included under 

173 ``'mu'`` and/or ``'tau'``, this field contains the original 

174 values. 

175 'aux' : structured 

176 The values in `x_aux`, if specified. 

177  

178 x_aux : (n, k) array, series or dataframe, optional 

179 Additional covariates for the ``'aux'`` process. 

180 otherhp : dictionary of gvar 

181 A dictionary with the prior of arbitrary additional hyperpameters, 

182 intended to be used by ``gpaux`` or ``transf``. 

183 transf : (list of) str or pair of callable 

184 Data transformation. Either a string indicating a pre-defined 

185 transformation, or a pair ``(from_data, to_data)``, two functions 

186 with signatures ``from_data(hp, y) -> eta`` and ``to_data(hp, eta) 

187 -> y``, where ``eta`` is the value to which the model is fit, and 

188 ``hp`` is the dictionary of hyperparameters. The functions must be 

189 ufuncs and one the inverse of the other w.r.t. the second parameter. 

190 ``from_data`` must be derivable with `jax` w.r.t. ``y``. 

191 

192 If a list of such specifications is provided, the transformations 

193 are applied in order, with the first one being the outermost, i.e., 

194 the one applied first to the data. 

195 

196 If a transformation uses additional hyperparameters, either 

197 predefined automatically or passed by the user through `otherhp`, 

198 they are inferred with the rest of the hyperparameters. 

199 

200 The pre-defined transformations are: 

201 

202 'standardize' (default) 

203 eta = (y - mean(train_y)) / sdev(train_y) 

204 'yeojohnson' 

205 The Yeo-Johnson transformation [2]_ to reduce skewness. The 

206 :math:`\lambda` parameter is bounded in :math:`(0, 2)` 

207 for implementation convenience, this restriction may be lifted 

208 in future versions. 

209  

210 Notes 

211 ----- 

212 The regression model is: 

213 

214 .. math:: 

215 \eta_i = g(y_i; \ldots) &= m + {} \\ 

216 &\phantom{{}={}} + 

217 \lambda_\mu 

218 \mu(\mathbf x^\mu_i, \hat\pi_i?) + {} \\ 

219 &\phantom{{}={}} + 

220 \lambda_\tau 

221 \tau(\mathbf x^\tau_i, \hat\pi_i?) (z_i - z_0) + {} \\ 

222 &\phantom{{}={}} + 

223 \mathrm{aux}(i, z_i, \mathbf x^\mu_i, \mathbf x^\tau_i, 

224 \hat\pi_i, \mathbf x^\text{aux}_i) + {} \\ 

225 &\phantom{{}={}} + 

226 \varepsilon_i, \\ 

227 \varepsilon_i &\sim 

228 N(0, \sigma^2 / w_i), \\ 

229 m &\sim N(0, 1), \\ 

230 \log \sigma^2 &\sim N(\log\bar w, 4), \\ 

231 \lambda_\mu 

232 &\sim \mathrm{HalfCauchy}(2), \\ 

233 \lambda_\tau 

234 &\sim \mathrm{HalfNormal}(1.48), \\ 

235 \mu &\sim \mathrm{GP}(0, 

236 \mathrm{BART}(\alpha_\mu, \beta_\mu) ), \\ 

237 \tau &\sim \mathrm{GP}(0, 

238 \mathrm{BART}(\alpha_\tau, \beta_\tau) ), \\ 

239 \mathrm{aux} & \sim \mathrm{GP}(0, \text{<user defined>}), \\ 

240 \alpha_\mu, \alpha_\tau &\sim \mathrm{Beta}(2, 1), \\ 

241 \beta_\mu, \beta_\tau &\sim \mathrm{InvGamma}(1, 1), \\ 

242 z_0 &\sim U(0, 1), 

243 

244 To make the inference, :math:`(\mu, \tau, \boldsymbol\varepsilon, m, 

245 \mathrm{aux})` are marginalized analytically, and the marginal posterior 

246 mode of 

247 :math:`(\sigma, \lambda_*, \alpha_*, \beta_*, z_0, \ldots)` is found by 

248 numerical minimization, after transforming them to express their prior 

249 as a Gaussian copula. Their marginal posterior covariance matrix is 

250 estimated with an approximation of the hessian inverse. See 

251 `~lsqfitgp.empbayes_fit` and use the parameter ``fitkw`` to customize 

252 this procedure. 

253 

254 The tree splitting grid of the BART kernel is set using quantiles of the 

255 observed covariates. This corresponds to settings ``usequants=True``, 

256 ``numcut=inf`` in the R packages BayesTree and BART. Use the parameters 

257 `kernelkw_mu` and `kernelkw_tau` to customize the grids. 

258 

259 The difference between the regression model evaluated at :math:`Z=1` vs. 

260 :math:`Z=0` can be interpreted as the causal effect :math:`Z \rightarrow 

261 Y` if the unconfoundedness assumption is made: 

262 

263 .. math:: 

264 \{Y(Z=0), Y(Z=1)\} \perp\!\!\!\perp Z \mid X. 

265 

266 In practical terms, this holds when: 

267 

268 1) :math:`X` are pre-treatment variables, i.e., they represent 

269 quantities causally upstream of :math:`Z`. 

270 

271 2) :math:`X` are sufficient to adjust for all common causes of 

272 :math:`Z` and :math:`Y`, such that the only remaining difference 

273 is the causal effect and not just a correlation. 

274 

275 Here :math:`X` consists in `x_tau`, `x_mu` and `x_aux`. However these 

276 arrays may also be used to pass "technical" values used to set up the 

277 model, that do not satisfy the uncounfoundedness assumption, if you know 

278 what you are doing. 

279 

280 Attributes 

281 ---------- 

282 m : float or gvar 

283 The prior mean :math:`m`. 

284 sigma : gvar 

285 The error term standard deviation :math:`\sigma`. If there are 

286 weights, the sdev for each unit is obtained dividing ``sigma`` by 

287 sqrt(weight). 

288 alpha_mu, alpha_tau : gvar 

289 The numerator of the tree spawn probability :math:`\alpha_*` (named 

290 ``base`` in R bcf). 

291 beta_mu, beta_tau : gvar 

292 The depth exponent of the tree spawn probability :math:`\beta_*` 

293 (named ``power`` in R bcf). 

294 lambda_mu, lambda_tau : gvar 

295 The prior standard deviation :math:`\lambda_*`. 

296 z_0 : gvar 

297 The treatment coding parameter. 

298 fit : empbayes_fit 

299 The hyperparameters fit object. 

300 

301 Methods 

302 ------- 

303 gp : 

304 Create a GP object. 

305 data : 

306 Creates the dictionary to be passed to `GP.pred` to represent data. 

307 pred : 

308 Evaluate the regression function at given locations. 

309 from_data : 

310 Convert :math:`y` to :math:`\eta`. 

311 to_data : 

312 Convert :math:`\eta` to :math:`y`. 

313 

314 See also 

315 -------- 

316 lsqfitgp.BART 

317  

318 References 

319 ---------- 

320 .. [1] P. Richard Hahn, Jared S. Murray, Carlos M. Carvalho "Bayesian 

321 Regression Tree Models for Causal Inference: Regularization, 

322 Confounding, and Heterogeneous Effects (with Discussion)," Bayesian 

323 Analysis 15(3), 965-1056, September 2020, 

324 https://doi.org/10.1214/19-BA1195 

325 .. [2] Yeo, In-Kwon; Johnson, Richard A. (2000). "A New Family of Power 

326 Transformations to Improve Normality or Symmetry". Biometrika. 87 

327 (4): 954–959. https://doi.org/10.1093/biomet/87.4.954 
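
        Examples
        --------
        A minimal usage sketch on simulated data, with the propensity score
        taken as known (in practice it would be estimated, e.g., by a
        logistic regression of ``z`` on the covariates); it assumes `bcf` is
        importable from `lsqfitgp.bayestree` and that ``verbosity=0`` in
        ``fitkw`` silences the fit log:

        >>> import numpy as np
        >>> from lsqfitgp.bayestree import bcf
        >>> rng = np.random.default_rng(0)
        >>> n = 100
        >>> X = rng.standard_normal((n, 3))
        >>> z = rng.binomial(1, 0.5, n)
        >>> pihat = np.full(n, 0.5)
        >>> y = X @ [1., 2., 3.] + z * X[:, 0] + rng.standard_normal(n)
        >>> model = bcf(y=y, z=z, x_mu=X, pihat=pihat,
        ...     fitkw=dict(verbosity=0))
        >>> eta_mean, eta_cov = model.pred(z=z, x_mu=X, pihat=pihat)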

328 """ 

329 

330 # convert covariates to StructuredArray 

331 x_mu = self._to_structured(x_mu) 1abcde

332 if x_tau is not None: 1abcde

333 x_tau = self._to_structured(x_tau) 1bcde

334 assert x_tau.shape == x_mu.shape 1bcde

335 if x_aux is not None: 1abcde

336 x_aux = self._to_structured(x_aux) 1bcde

337 assert x_aux.shape == x_mu.shape 1bcde

338 

339 # convert outcomes, treatment, propensity score, weights to 1d arrays 

340 y = self._to_vector(y) 1abcde

341 z = self._to_vector(z) 1abcde

342 pihat = self._to_vector(pihat) 1abcde

343 assert y.shape == z.shape == pihat.shape == x_mu.shape 1abcde

344 if weights is not None: 344 ↛ 345line 344 didn't jump to line 345 because the condition on line 344 was never true1abcde

345 weights = self._to_vector(weights) 

346 assert weights.shape == x_mu.shape 

347 

348 # check include_pi 

349 if include_pi not in ('mu', 'tau', 'both'): 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true1abcde

350 raise KeyError(f'invalid value include_pi={include_pi!r}') 

351 self._include_pi = include_pi 1abcde

352 

353 # add pihat to covariates 

354 x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat) 1abcde

355 

356 # grid and indices 

357 splits_mu = _kernels.BART.splits_from_coord(x_mu) 1abcde

358 i_mu = self._toindices(x_mu, splits_mu) 1abcde

359 if x_tau is None: 1abcde

360 splits_tau = splits_mu 1abcde

361 i_tau = None 1abcde

362 else: 

363 splits_tau = _kernels.BART.splits_from_coord(x_tau) 1bcde

364 i_tau = self._toindices(x_tau, splits_tau) 1bcde

365 

366 # get functions for data transformation 

367 from_data, to_data, transfloss, transfhp = self._get_transf( 1abcde

368 transf=transf, weights=weights, y=y) 

369 

370 # scale of error variance 

371 logsigma2_loc = 0 if weights is None else numpy.log(jnp.mean(weights)) 1abcde

372 

373 # prior on hyperparams 

374 hyperprior = copula.makedict({ 1abcde

375 'm': gvar.gvar(0, 1), 

376 'sigma^2': copula.lognorm(logsigma2_loc, 2), 

377 'lambda_mu': copula.halfcauchy(2), 

378 'lambda_tau': copula.halfnorm(1.48), 

379 'alpha_mu': copula.beta(2, 1), 

380 'alpha_tau': copula.beta(2, 1), 

381 'beta_mu': copula.invgamma(1, 1), 

382 'beta_tau': copula.invgamma(1, 1), 

383 'z_0': copula.uniform(0, 1), 

384 }) 

385 


        # remove explicit mean parameter if it's baked into the Gaussian process
        if marginalize_mean:
            hyperprior.pop('m')

        # add data transformation and user hyperparameters
        def update_hyperparams(new, newname, raises):
            new = gvar.BufferDict(new)
            for key in new.all_keys():
                if hyperprior.has_dictkey(key):
                    message = f'{newname} hyperparameter {key!r} overrides existing one'
                    if raises:
                        raise ValueError(message)
                    else:
                        warnings.warn(message)
            hyperprior.update(new)
        update_hyperparams(transfhp, 'data transformation', True)
            # the hypers handed by _get_transf are not allowed to override
        update_hyperparams(otherhp, 'user', False)

        # GP factory
        def gpfactory(hp, *, z, i_mu, i_tau, pihat, x_aux, weights,
                splits_mu, splits_tau, **_):

            # TODO maybe I should pass kernelkw_* as arguments, but they may not
            # be jittable. I need jitkw in empbayes_fit for that.

            kw_overridable = dict(
                maxd=10,
                reset=[2, 4, 6, 8],
                intercept=False,
            )
            kw_not_overridable = dict(indices=True)

            gp = _GP.GP(checkpos=False, checksym=False, solver='chol')

            for name, kernelkw in dict(mu=kernelkw_mu, tau=kernelkw_tau).items():
                kw = dict(
                    alpha=hp[f'alpha_{name}'],
                    beta=hp[f'beta_{name}'],
                    dim=name,
                    splits=eval(f'splits_{name}'),
                    **kw_overridable,
                )
                kw.update(kernelkw)
                kernel = _kernels.BART(**kw, **kw_not_overridable)
                kernel *= hp[f'lambda_{name}'] ** 2

                gp = gp.defproc(name, kernel)

            if 'm' in hp:
                kernel_mean = 0 * _kernels.Constant()
            else:
                kernel_mean = _kernels.Constant()
            gp = gp.defproc('m', kernel_mean)

            if gpaux is None:
                gp = gp.defproc('aux', 0 * _kernels.Constant())
            else:
                gp = gpaux(hp, gp)

            gp = gp.deflintransf(
                gp.DefaultProcess,
                lambda m, mu, tau, aux: lambda x:
                    m(x) + mu(x) + tau(x) * (x['z'] - hp['z_0']) + aux(x),
                ['m', 'mu', 'tau', 'aux'],
            )

            x = self._join_points(True, z, i_mu, i_tau, pihat, x_aux)
            gp = gp.addx(x, 'trainmean')
            errcov = self._error_cov(hp, weights, x)
            return (gp
                .addcov(errcov, 'trainnoise')
                .addtransf({'trainmean': 1, 'trainnoise': 1}, 'train')
            )

        # data factory
        def data(hp, *, y, **_):
            return {'train': from_data(hp, y) - hp.get('m', 0)}

        # fit hyperparameters
        options = dict(
            verbosity=3,
            minkw=dict(
                method='l-bfgs-b',
                options=dict(
                    maxls=4,
                    maxiter=100,
                ),
            ),
            mlkw=dict(
                epsrel=0,
            ),
            forward=True,
            gpfactorykw=dict(
                y=y,
                z=z,
                i_mu=i_mu,
                i_tau=i_tau,
                pihat=pihat,
                x_aux=x_aux,
                weights=weights,
                splits_mu=splits_mu,
                splits_tau=splits_tau,
            ),
            additional_loss=transfloss,
        )
        options.update(fitkw)
        fit = _fit.empbayes_fit(hyperprior, gpfactory, data, **options)

        # extract hyperparameters from minimization result
        self.m = fit.p.get('m', 0)
        self.sigma = gvar.sqrt(fit.p['sigma^2'])
        self.lambda_mu = fit.p['lambda_mu']
        self.lambda_tau = fit.p['lambda_tau']
        self.alpha_mu = fit.p['alpha_mu']
        self.alpha_tau = fit.p['alpha_tau']
        self.beta_mu = fit.p['beta_mu']
        self.beta_tau = fit.p['beta_tau']
        self.z_0 = fit.p['z_0']

        # save other attributes
        self.fit = fit
        self._from_data = from_data
        self._to_data = to_data

    def _append_pihat(self, x_mu, x_tau, pihat):
        ip = self._include_pi
        if ip == 'mu' or ip == 'both':
            x_mu = _array.StructuredArray.from_dict(dict(
                x=x_mu,
                pihat=pihat,
            ))
        if x_tau is not None and (ip == 'tau' or ip == 'both'):
            x_tau = _array.StructuredArray.from_dict(dict(
                x=x_tau,
                pihat=pihat,
            ))
        return x_mu, x_tau

    @staticmethod
    def _join_points(train, z, i_mu, i_tau, pihat, x_aux):
        """ join covariates into a single StructuredArray """
        columns = dict(
            train=jnp.broadcast_to(bool(train), z.shape),
            i=jnp.arange(z.size).reshape(z.shape),
            z=z,
            mu=i_mu,
            tau=i_mu if i_tau is None else i_tau,
            pihat=pihat,
        )
        if x_aux is not None:
            columns.update(aux=x_aux)
        return _array.StructuredArray.from_dict(columns)

    @staticmethod
    def _error_cov(hp, weights, x):
        """ fill error covariance matrix """
        if weights is None:
            error_var = jnp.broadcast_to(hp['sigma^2'], len(x))
        else:
            error_var = hp['sigma^2'] / weights
        return jnp.diag(error_var)

    def _gethp(self, hp, rng):
        if not isinstance(hp, str):
            return hp
        elif hp == 'map':
            return self.fit.pmean
        elif hp == 'sample':
            return _fastraniter.sample(self.fit.pmean, self.fit.pcov, rng=rng)
        else:
            raise KeyError(hp)

    def gp(self, *, hp='map', z=None, x_mu=None, x_tau=None, pihat=None,
        x_aux=None, weights=None, rng=None):
        """
        Create a Gaussian process with the fitted hyperparameters.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        z : (m,) array, series or dataframe, optional
            Treatment status at test points. If specified, `x_mu`, `pihat`,
            `x_tau` and `x_aux` (the latter two if and only if also specified
            at initialization) must be specified as well.
        x_mu : (m, p) array, series or dataframe, optional
            Control model covariates at test points.
        x_tau : (m, q) array, series or dataframe, optional
            Moderating model covariates at test points.
        pihat : (m,) array, series or dataframe, optional
            Estimated propensity score at test points.
        x_aux : (m, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process.
        weights : (m,) array, series or dataframe, optional
            Weights for the error variance on the test points.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        gp : GP
            A centered Gaussian process object. To add the mean, use the `m`
            attribute of the `bcf` object. The keys of the GP are ``'@mean'``,
            ``'@noise'``, and ``'@'``, where the "@" stands for either 'train'
            or 'test', and @ = @mean + @noise.

            This Gaussian process is defined on the transformed data ``eta``.
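
        Examples
        --------
        A sketch of manual use of the GP object (``model`` stands for a
        fitted `bcf` instance; `~lsqfitgp.GP.predfromdata` is the same
        prediction method used internally by `pred`):

        >>> gp = model.gp()
        >>> data = model.data()
        >>> mean, cov = gp.predfromdata(data, 'trainmean', raw=True)

        This gives the posterior mean and covariance of the latent regression
        function at the training points, in the transformed space ``eta`` and
        without the prior mean `m`.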

596 """ 

597 

598 hp = self._gethp(hp, rng) 

599 return self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw) 

600 


    def _gp(self, hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw):
        """
        Internal function to create the GP object. This function must work
        both when the arguments are user-provided, and so need to be checked
        and converted to a standard format, and when they are traced jax
        values.
        """

        # create GP object
        gp = self.fit.gpfactory(hp, **gpfactorykw)

        # add test points
        if z is not None:

            # check presence/absence of arguments is coherent
            self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux)

            # check treatment and propensity score
            z = self._to_vector(z)
            pihat = self._to_vector(pihat)
            assert pihat.shape == z.shape

            # check weights
            if weights is not None:
                weights = self._to_vector(weights)
                assert weights.shape == z.shape

            # add propensity score to covariates
            x_mu = self._to_structured(x_mu)
            assert x_mu.shape == z.shape
            if x_tau is not None:
                x_tau = self._to_structured(x_tau)
                assert x_tau.shape == z.shape
            x_mu, x_tau = self._append_pihat(x_mu, x_tau, pihat)

            # convert covariates to indices
            i_mu = self._toindices(x_mu, gpfactorykw['splits_mu'])
            assert i_mu.dtype == gpfactorykw['i_mu'].dtype
            if x_tau is not None:
                i_tau = self._toindices(x_tau, gpfactorykw['splits_tau'])
                assert i_tau.dtype == gpfactorykw['i_tau'].dtype
            else:
                i_tau = None

            # check auxiliary points
            if x_aux is not None:
                x_aux = self._to_structured(x_aux)

            # add test points
            x = self._join_points(False, z, i_mu, i_tau, pihat, x_aux)
            gp = gp.addx(x, 'testmean')
            errcov = self._error_cov(hp, weights, x)
            gp = (gp
                .addcov(errcov, 'testnoise')
                .addtransf({'testmean': 1, 'testnoise': 1}, 'test')
            )

        return gp

    def _check_coherent_covariates(self, z, x_mu, x_tau, pihat, x_aux):
        if z is None:
            assert x_mu is None
            assert x_tau is None
            assert pihat is None
            assert x_aux is None
        else:
            assert x_mu is not None
            assert pihat is not None
            train_tau = self.fit.gpfactorykw['i_tau']
            if x_tau is None:
                assert train_tau is None
            else:
                assert train_tau is not None
            train_aux = self.fit.gpfactorykw['x_aux']
            if x_aux is None:
                assert train_aux is None
            else:
                assert train_aux is not None

    def data(self, *, hp='map', rng=None):
        """
        Get the data to be passed to `GP.pred` on a GP object returned by `gp`.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        data : dict
            A dictionary representing ``eta`` in the format required by the
            `GP.pred` method.
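
        Examples
        --------
        A sketch of the relationship with `from_data` (``model`` stands for a
        fitted `bcf` instance):

        >>> data = model.data()
        >>> eta = data['train']

        With the default ``marginalize_mean=True``, ``eta`` coincides with
        ``model.from_data(y)``; otherwise the prior mean `m` is subtracted.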

697 """ 

698 

699 hp = self._gethp(hp, rng) 

700 return self.fit.data(hp, **self.fit.gpfactorykw) 

701 


    def pred(self, *, hp='map', error=False, z=None, x_mu=None, x_tau=None,
        pihat=None, x_aux=None, weights=None, transformed=True, samples=None,
        gvars=False, rng=None):
        r"""
        Predict the transformed outcome at given locations.

        Parameters
        ----------
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        error : bool, default False
            If ``False``, make a prediction for the latent mean. If ``True``,
            add the error term.
        z : (m,) array, series or dataframe, optional
            Treatment status at test points. If specified, `x_mu`, `pihat`,
            `x_tau` and `x_aux` (the latter two if and only if also specified
            at initialization) must be specified as well.
        x_mu : (m, p) array, series or dataframe, optional
            :math:`\mu` model covariates at test points.
        x_tau : (m, q) array, series or dataframe, optional
            :math:`\tau` model covariates at test points.
        pihat : (m,) array, series or dataframe, optional
            Estimated propensity score at test points.
        x_aux : (m, k) array, series or dataframe, optional
            Additional covariates for the ``'aux'`` process at test points.
        weights : (m,) array, series or dataframe, optional
            Weights for the error variance on the test points.
        transformed : bool, default True
            If ``True``, return the prediction on the transformed outcome
            :math:`\eta`, else on the observable outcome :math:`y`.
        samples : int, optional
            If specified, the number of samples to take from the posterior.
            If not, return the mean and covariance matrix of the posterior.
        gvars : bool, default False
            If ``True``, return the mean and covariance matrix of the
            posterior as an array of `GVar` variables.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'`` or ``samples``
            is not `None`.

        Returns
        -------
        If ``samples`` is `None` and ``gvars`` is `False` (default):

        mean, cov : (m,) and (m, m) arrays
            The mean and covariance matrix of the Normal posterior
            distribution over the regression function or :math:`\eta` at the
            specified locations.

        If ``samples`` is `None` and ``gvars`` is `True`:

        out : (m,) array of gvars
            The same distribution represented as an array of `~gvar.GVar`
            objects.

        If ``samples`` is an integer:

        sample : (samples, m) array
            Posterior samples over either the regression function,
            :math:`\eta`, or :math:`y`.
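
        Examples
        --------
        A sketch of out-of-sample prediction (``model``, ``Xnew`` and
        ``pihat_new`` are made-up names; the covariate arguments mirror the
        ones passed at initialization):

        >>> import numpy as np
        >>> znew = np.zeros(len(Xnew))  # predict under control
        >>> sample = model.pred(z=znew, x_mu=Xnew, pihat=pihat_new,
        ...     error=True, transformed=False, samples=100)

        With ``transformed=False`` the samples are mapped back to the data
        space of :math:`y`, which requires sampling with the error term
        included, as explained above.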

765 """ 

766 

767 # check consistency of output choice 

768 if samples is None: 768 ↛ 769line 768 didn't jump to line 769 because the condition on line 768 was never true1abcde

769 if not transformed: 

770 raise ValueError('Posterior is required in analytical form ' 

771 '(samples=None) and in data space ' 

772 '(transformed=False), this is not possible as ' 

773 'the transformation model space -> data space ' 

774 'is arbitrary. Either sample the posterior or ' 

775 'get the result in model space.') 

776 else: 

777 if not transformed and not error: 777 ↛ 778line 777 didn't jump to line 778 because the condition on line 777 was never true1abcde

778 raise ValueError('Posterior is required in data space ' 

779 '(transformed=False) and without error term ' 

780 '(error=False), this is not possible as the ' 

781 'transformation model space -> data space ' 

782 'applies after adding the error.') 

783 assert not gvars, 'can not represent posterior samples as gvars' 1abcde

784 

785 # TODO allow exceptions to these rules when there are no transformations 

786 # or the only transformation is 'standardize'. 

787 

788 # get hyperparameters 

789 hp = self._gethp(hp, rng) 1abcde

790 

791 # check presence of covariates is coherent 

792 self._check_coherent_covariates(z, x_mu, x_tau, pihat, x_aux) 1abcde

793 

794 # convert all inputs to arrays compatible with jax to pass them to the 

795 # compiled implementation 

796 if z is not None: 1abcde

797 z = self._to_vector(z) 1a

798 pihat = self._to_vector(pihat) 1a

799 x_mu = self._to_structured(x_mu) 1a

800 if x_tau is not None: 800 ↛ 801line 800 didn't jump to line 801 because the condition on line 800 was never true1a

801 x_tau = self._to_structured(x_tau) 

802 if x_aux is not None: 802 ↛ 803line 802 didn't jump to line 803 because the condition on line 802 was never true1a

803 x_aux = self._to_structured(x_aux) 

804 if weights is not None: 804 ↛ 805line 804 didn't jump to line 805 because the condition on line 804 was never true1abcde

805 weights = self._to_vector(weights) 

806 

807 # GP regression 

808 mean, cov = self._pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, self.fit.gpfactorykw, bool(error)) 1abcde

809 

810 # return Normal posterior moments 

811 if samples is None: 811 ↛ 812line 811 didn't jump to line 812 because the condition on line 811 was never true1abcde

812 if gvars: 

813 return gvar.gvar(mean, cov, fast=True) 

814 else: 

815 return mean, cov 

816 

817 # sample from posterior 

818 sample = jnp.stack(list(_fastraniter.raniter(mean, cov, n=samples, rng=rng))) 1abcde

819 # TODO when I add vectorized sampling, use it here 

820 if not transformed: 820 ↛ 822line 820 didn't jump to line 822 because the condition on line 820 was always true1abcde

821 sample = self._to_data(hp, sample) 1abcde

822 return sample 1abcde

823 

824 # TODO the default should be something in data space, so with samples. 

825 # If I handle the analyitical posterior through standardize, I could 

826 # also make it without samples by default. Although I guess for 

827 # whatever calculations samples are more convenient (just do the 

828 # calculation on the samples.) 

829 

830 @functools.cached_property 1fabcde

831 def _pred(self): 1fabcde

832 

833 @functools.partial(jax.jit, static_argnums=(8,)) 1abcde

834 def _pred(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw, error): 1abcde

835 gp = self._gp(hp, z, x_mu, x_tau, pihat, x_aux, weights, gpfactorykw) 1abcde

836 data = self.fit.data(hp, **gpfactorykw) 1abcde

837 if z is None: 1abcde

838 label = 'train' 1bcde

839 else: 

840 label = 'test' 1a

841 if not error: 841 ↛ 842line 841 didn't jump to line 842 because the condition on line 841 was never true1abcde

842 label += 'mean' 

843 outmean, outcov = gp.predfromdata(data, label, raw=True) 1abcde

844 return outmean + hp.get('m', 0), outcov 1abcde

845 

846 # TODO make everything pure and jit this per class instead of per 

847 # instance 

848 

849 return _pred 1abcde

850 


    def from_data(self, y, *, hp='map', rng=None):
        """
        Transform outcomes :math:`y` to the regression variable :math:`\\eta`.

        Parameters
        ----------
        y : (n,) array
            Outcomes.
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        eta : (n,) array
            Transformed outcomes.
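
        Examples
        --------
        A sketch of the round trip with `to_data` (``model`` stands for a
        fitted `bcf` instance; at fixed hyperparameters the two
        transformations are inverses of each other):

        >>> eta = model.from_data(y)
        >>> y_back = model.to_data(eta)  # equals y up to numerical error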

870 """ 

871 

872 hp = self._gethp(hp, rng) 1bcde

873 return self._from_data(hp, y) 1bcde

874 


    def to_data(self, eta, *, hp='map', rng=None):
        """
        Convert the regression variable :math:`\\eta` to outcomes :math:`y`.

        Parameters
        ----------
        eta : (n,) array
            Transformed outcomes.
        hp : str or dict
            The hyperparameters to use. If ``'map'`` (default), use the
            marginal maximum a posteriori. If ``'sample'``, sample
            hyperparameters from the posterior. If a dict, use the given
            hyperparameters.
        rng : numpy.random.Generator, optional
            Random number generator, used if ``hp == 'sample'``.

        Returns
        -------
        y : (n,) array
            Outcomes.
        """

        hp = self._gethp(hp, rng)
        return self._to_data(hp, eta)

    @classmethod
    def _to_structured(cls, x, *, check_numerical=True):

        # convert to StructuredArray
        if hasattr(x, 'columns'):
            x = _array.StructuredArray.from_dataframe(x)
        elif hasattr(x, 'to_numpy'):
            x = _array.StructuredArray.from_dict({
                'f0' if x.name is None else x.name: x.to_numpy()
            })
        elif x.dtype.names is None:
            x = _array.unstructured_to_structured(x)
        else:
            x = _array.StructuredArray(x)

        # check fields are numerical, for BART
        if check_numerical:
            assert x.ndim == 1
            assert x.size > len(x.dtype)
            def check_numerical(path, dtype):
                if not numpy.issubdtype(dtype, numpy.number):
                    raise TypeError(f'covariate `{path}` is not numerical')
            cls._walk_dtype(x.dtype, check_numerical)

        return x

    @staticmethod
    def _to_vector(x):
        if hasattr(x, 'columns'):  # dataframe
            x = x.to_numpy().squeeze(axis=1)
        elif hasattr(x, 'to_numpy'):  # series (dataframe column)
            x = x.to_numpy()
        x = jnp.asarray(x)
        if x.ndim != 1:
            raise ValueError(f'array is not 1d vector, ndim={x.ndim}')
        return x

    @classmethod
    def _walk_dtype(cls, dtype, task, path=None):
        if dtype.names is None:
            task(path, dtype)
        else:
            for name in dtype.names:
                subpath = name if path is None else path + ':' + name
                cls._walk_dtype(dtype[name], task, subpath)

    @staticmethod
    def _toindices(x, splits):
        ix = _kernels.BART.indices_from_coord(x, splits)
        dtype = cast(x.dtype, ix.dtype)
        return _array.unstructured_to_structured(ix, dtype=dtype)

    def __repr__(self):

        with _gvarext.gvar_format():
            if hasattr(self.m, 'sdev'):
                m = str(self.m)
            else:
                m = f'{self.m:.3g}'

            n = self.fit.gpfactorykw['y'].size
            p_mu = _array._nd(self.fit.gpfactorykw['i_mu']['x'].dtype)
            x_tau = self.fit.gpfactorykw['i_tau']
            x_aux = self.fit.gpfactorykw['x_aux']

            out = f"""\
Data:
    n = {n}"""

            if x_tau is None:
                out += f"""
    p = {p_mu}"""
            else:
                p_tau = _array._nd(x_tau['x'].dtype)
                out += f"""
    p_mu/tau = {p_mu}, {p_tau}"""

            if x_aux is not None:
                p_aux = _array._nd(x_aux['x'].dtype)
                out += f"""
    p_aux = {p_aux}"""

            out += f"""
Hyperparameter posterior:
    m = {m}
    z_0 = {self.z_0}
    alpha_mu/tau = {self.alpha_mu} {self.alpha_tau}
    beta_mu/tau = {self.beta_mu} {self.beta_tau}
    lambda_mu/tau = {self.lambda_mu} {self.lambda_tau}"""

            weights = self.fit.gpfactorykw['weights']
            if weights is None:
                out += f"""
    sigma = {self.sigma}"""

            else:
                weights = numpy.array(weights)  # to avoid jax taking over the ops
                avgsigma = numpy.sqrt(numpy.mean(self.sigma ** 2 / weights))
                out += f"""
    sqrt(mean(sigma^2/w)) = {avgsigma}
    sigma = {self.sigma}"""

            out += """
Meaning of hyperparameters:
    mu(x) = reference outcome level
    tau(x) = effect of the treatment
    z_0 in (0, 1): reference treatment level
        z_0 -> 0: mu is the model of the untreated
        z_0 -> 1: mu is the model of the treated
    alpha in (0, 1)
        alpha -> 0: constant function
        alpha -> 1: no constraints on the function
    beta in (0, ∞)
        beta -> 0: no constraints on the function
        beta -> ∞: no interactions, f(x) = f1(x1) + f2(x2) + ...
    lambda in (0, ∞): standard deviation of function
        lambda small: confident extrapolation
        lambda large: conservative extrapolation
    sigma in (0, ∞): standard deviation of i.i.d. error"""

        return _utils.top_bottom_rule('BCF', out)

    # TODO print user parameters, applying transformations. Copy the dict and
    # use .pop() to remove the predefined params as they are printed.

    def _get_transf(self, *, transf, y, weights):

        from_datas = []
        to_datas = []
        hypers = {}

        if transf is None:
            transf = []
        elif isinstance(transf, list):
            name = lambda n: f'transf{i}_{n}'
        else:
            name = lambda n: n
            transf = [transf]

        for i, tr in enumerate(transf):

            hyper = {}

            if not isinstance(tr, str):

                from_data, to_data = tr

            elif tr == 'standardize':

                if i > 0:
                    warnings.warn('standardization applied after other '
                        'transformations: standardization always uses the '
                        'initial data mean and standard deviation, so it may '
                        'not work as intended')
                    # It's not possible to overcome this limitation if one
                    # wants to stick to transformations that act on one point
                    # at a time to make them generalizable out of sample.

                if weights is None:
                    loc = jnp.mean(y)
                    scale = jnp.std(y)
                else:
                    loc = jnp.average(y, weights=weights)
                    scale = jnp.sqrt(jnp.average((y - loc) ** 2, weights=weights))

                def from_data(hp, y):
                    return (y - loc) / scale
                def to_data(hp, eta):
                    return loc + scale * eta

            elif tr == 'yeojohnson':

                def from_data(hp, y):
                    return yeojohnson(y, hp[name('lambda_yj')])
                def to_data(hp, eta):
                    return yeojohnson_inverse(eta, hp[name('lambda_yj')])
                hyper[name('lambda_yj')] = 2 * copula.beta(2, 2)

            else:
                raise KeyError(tr)

            from_datas.append(from_data)
            to_datas.append(to_data)
            hypers.update(hyper)

        if transf:
            def from_data(hp, y):
                for fd in from_datas:
                    y = fd(hp, y)
                return y
            def to_data(hp, eta):
                for td in reversed(to_datas):
                    eta = td(hp, eta)
                return eta
        else:
            from_data = lambda hp, y: y
            to_data = lambda hp, eta: eta

        from_data_grad = _jaxext.elementwise_grad(from_data, 1)
        def loss(hp):
            return -jnp.sum(jnp.log(from_data_grad(hp, y)))

        hypers = copula.makedict(hypers)

        return from_data, to_data, loss, hypers

def yeojohnson(x, lmbda):
    """ Yeo-Johnson transformation with lmbda != 0, 2 """
    return jnp.where(
        x >= 0,
        (jnp.power(x + 1, lmbda) - 1) / lmbda,
        -((jnp.power(-x + 1, 2 - lmbda) - 1) / (2 - lmbda))
    )

    # TODO
    # - rewrite the cases with expm1, log1p, etc. to make them accurate
    # - split the cases into lambda 0/2
    # - make custom_jvps for the singular points to define derivatives w.r.t.
    #   lambda even though it does not appear in the expression
    # - add unit tests that check gradients with finite differences

def yeojohnson_inverse(y, lmbda):
    return jnp.where(
        y >= 0,
        jnp.power(y * lmbda + 1, 1 / lmbda) - 1,
        -jnp.power(-(2 - lmbda) * y + 1, 1 / (2 - lmbda)) + 1
    )
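
# A quick sanity check of the pair above (illustrative, not part of the
# module): for lmbda strictly inside (0, 2), yeojohnson_inverse undoes
# yeojohnson, which can be verified numerically, e.g.:
#
#     x = jnp.linspace(-3., 3., 11)
#     assert jnp.allclose(yeojohnson_inverse(yeojohnson(x, 0.7), 0.7), x)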