Coverage for src/lsqfitgp/_kernels/_bart.py: 99% of 354 statements (coverage.py v7.6.3, created at 2024-10-15 19:54 +0000)

# lsqfitgp/_kernels/_bart.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools

import jax
from jax import numpy as jnp
from jax import lax
from jax.scipy import special as jspecial
from numpy.lib import recfunctions

from .. import _jaxext
from .. import _array
from .._Kernel import kernel

@kernel(derivable=False, batchbytes=10e6)
# TODO maybe batching should be done automatically by GP instead of by the
# kernels? But before doing that I need to support batching non-traceable
# functions.
def _BARTBase(x, y,
    alpha=0.95,
    beta=2,
    maxd=2,
    gamma=1,
    splits=None,
    pnt=None,
    intercept=True,
    weights=None,
    reset=None,
    indices=False):
    """
    BART kernel.

    Good default parameters: ``maxd=4, reset=2`` if ``alpha`` and ``beta`` are
    kept fixed at the default values, ``maxd=10, reset=[2,4,6,8]`` otherwise.
    Derivatives are faster with forward autodiff.

    Parameters
    ----------
    x, y : arrays
        Input points. The array type can be structured, in which case every
        leaf field represents a dimension; or unstructured, in which case it
        specifies a single dimension.
    alpha, beta : scalar
        The parameters of the branching probability.
    maxd : int
        The maximum depth of the trees.
    splits : pair of arrays
        The first is an int (p,) array containing the number of splitting
        points along each dimension, the second has shape (n, p) and contains
        the sorted splitting points in each column, filled with high values
        after the length. Use `BART.splits_from_coord` to produce them.
    gamma : scalar or str
        Interpolation coefficient in [0, 1] between a lower and an upper
        bound on the infinite maxd limit, or a string 'auto' indicating to
        use a formula which depends on alpha, beta, maxd and the number of
        covariates, empirically calibrated on maxd from 1 to 3. Default 1
        (upper bound).
    pnt : (maxd + 1,) array, optional
        Nontermination probabilities at depths 0...maxd. If specified,
        ``alpha``, ``beta`` and ``maxd`` are ignored.
    intercept : bool, default True
        The correlation is in [1 - alpha, 1] (or [1 - pnt[0], 1] when using
        pnt). If intercept=False, it is rescaled to [0, 1].
    weights : (p,) array, optional
        Unnormalized selection probabilities for the covariate axes. If not
        specified, all axes have the same probability to be selected for
        splitting.
    reset : int or sequence of int, optional
        List of depths at which the recursion is reset, in the sense that the
        function value at a reset depth is evaluated on the initial inputs for
        all recursion paths, instead of the modified input handed down by the
        recursion. Default none.
    indices : bool, default False
        If False, the inputs `x`, `y` represent coordinate values. If True,
        they are taken to be already the indices of the points in the splitting
        grid, as can be obtained with `BART.indices_from_coord`.

    Methods
    -------
    splits_from_coord
    indices_from_coord
    correlation

    Notes
    -----
    This is the covariance function of the latent mean prior of BART (Bayesian
    Additive Regression Trees) [1]_ with an upper bound :math:`D` on the depth
    of the trees. This prior is the distribution of the function

    .. math::
        f(\\mathbf x) = \\lim_{m\\to\\infty}
        \\sum_{j=1}^m g(\\mathbf x; T_j, M_j),

    where each :math:`g(\\mathbf x; T_j, M_j)` is a decision tree evaluated at
    :math:`\\mathbf x`, with structure :math:`T_j` and leaf values :math:`M_j`.
    The trees are i.i.d., with the following distribution for :math:`T_j`: for
    a node at depth :math:`d`, with :math:`d = 0` for the root, the probability
    of not being a leaf, conditional on its existence and its ancestors only, is

    .. math::
        P_d = \\alpha (1+d)^{-\\beta}, \\quad
        \\alpha \\in [0, 1], \\quad \\beta \\ge 0.

    For a non-leaf node, conditional on existence and ancestors, the splitting
    variable has uniform distribution amongst the variables with any splitting
    points not used by ancestors, and the splitting point has uniform
    distribution amongst the available ones. The splitting points are fixed,
    typically from the data.

    The distribution of leaves :math:`M_j` is i.i.d. Normal with variance
    :math:`1/m`, such that :math:`f(x)` has variance 1. In the limit
    :math:`m\\to\\infty`, the distribution of :math:`f(x)` becomes a Gaussian
    process.

    Since the trees are independent, the covariance function can be computed
    for a single tree. Consider two coordinates :math:`x` and :math:`y`, with
    :math:`x \\le y`. Let :math:`n^-`, :math:`n^0` and :math:`n^+` be the
    number of splitting points respectively before :math:`x`, between
    :math:`x` and :math:`y`, and after :math:`y`. Next, define :math:`\\mathbf
    n^-`, :math:`\\mathbf n^0` and :math:`\\mathbf n^+` as the vectors of such
    quantities for each dimension, with a total of :math:`p` dimensions, and
    :math:`\\mathbf n = \\mathbf n^- + \\mathbf n^0 + \\mathbf n^+`. Then the
    covariance function can be written recursively as

    .. math::
        \\newcommand{\\nvecs}{\\mathbf n^-, \\mathbf n^0, \\mathbf n^+}
        k(\\mathbf x, \\mathbf y) &= k_0(\\nvecs), \\\\
        k_D(\\nvecs) &= 1 - (1 - \\gamma) P_D,
        \\quad \\mathbf n^0 \\ne \\mathbf 0, \\\\
        k_d(\\mathbf 0, \\mathbf 0, \\mathbf 0) &= 1, \\\\
        k_d(\\nvecs) &= 1 - P_d \\Bigg(1 - \\frac1{W(\\mathbf n)}
        \\sum_{\\substack{i=1 \\\\ n_i\\ne 0}}^p
        \\frac{w_i}{n_i} \\Bigg( \\\\
        &\\qquad \\sum_{k=0}^{n^-_i - 1}
        k_{d+1}(\\mathbf n^-_{n^-_i=k}, \\mathbf n^0, \\mathbf n^+)
        + {} \\\\
        &\\qquad \\sum_{k=0}^{n^+_i - 1}
        k_{d+1}(\\mathbf n^-, \\mathbf n^0, \\mathbf n^+_{n^+_i=k})
        \\Bigg)
        \\Bigg), \\quad d < D, \\\\
        W(\\mathbf n) &= \\sum_{\\substack{i=1 \\\\ n_i\\ne 0}}^p w_i.

    The introduction of a maximum depth :math:`D` is necessary for
    computational feasibility. As :math:`D` increases, the result converges to
    the one without depth limit. For :math:`D \\le 2` (the default value), the
    covariance is implemented in closed form and takes :math:`O(p)` to compute.
    For :math:`D > 2`, the computational complexity grows exponentially as
    :math:`O(p(\\bar np)^{D-2})`, where :math:`\\bar n` is the average number
    of splitting points along a dimension.

    If the maximum allowed depth is 1, i.e., either :math:`D = 1` or
    :math:`\\beta\\to\\infty`, the kernel assumes the simple form

    .. math::
        k(\\mathbf x, \\mathbf y) &= 1 - P_0 \\left(
        1 - Q + \\frac Q{W(\\mathbf n)}
        \\sum_{\\substack{i=1 \\\\ n_i\\ne 0}}^p w_i
        \\frac{n^0_i}{n_i} \\right), \\\\
        Q &= \\begin{cases}
            1 - (1 - \\gamma) P_1 & \\mathbf n^0 \\ne \\mathbf 0, \\\\
            1 & \\mathbf n^0 = \\mathbf 0,
        \\end{cases}

    which is separable along dimensions, i.e., it has no interactions.
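
    For example, with a single covariate (:math:`p = 1`, :math:`w_1 = 1`, at
    least one splitting point) and :math:`\\gamma = 1`, so that :math:`Q = 1`,
    a direct substitution reduces this to

    .. math::
        k(x, y) = 1 - P_0 \\frac{n^0}{n},

    i.e., the correlation decreases linearly in the fraction of splitting
    points falling between the two points.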

    References
    ----------
    .. [1] Hugh A. Chipman, Edward I. George, Robert E. McCulloch, "BART:
        Bayesian additive regression trees," The Annals of Applied Statistics,
        Ann. Appl. Stat. 4(1), 266-298, (March 2010).
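
    Examples
    --------
    A minimal usage sketch with a single covariate (illustrative values; this
    assumes the kernel is exposed as ``lsqfitgp.BART`` and that kernel
    instances are evaluated as ``kernel(x, y)`` with broadcasting):

    >>> import numpy as np
    >>> import lsqfitgp as lgp
    >>> x = np.linspace(0, 1, 20)                        # observed coordinates
    >>> splits = lgp.BART.splits_from_coord(x[:, None])  # splitting grid
    >>> kernel = lgp.BART(splits=splits, maxd=4, reset=2)
    >>> cov = kernel(x[:, None], x[None, :])             # (20, 20) prior covariance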

    """

    splits = BART._check_splits(splits, indices)
    if not x.dtype.names:
        x = x[..., None]
    if not y.dtype.names:
        y = y[..., None]
    if indices:
        ix = BART._check_x(x)
        iy = BART._check_x(y)
    else:
        ix = BART._indices_from_coord(x, splits)
        iy = BART._indices_from_coord(y, splits)
    return BART.correlation(
        splits[0], ix, iy,
        pnt=pnt, alpha=alpha, beta=beta, gamma=gamma, maxd=maxd,
        intercept=intercept, weights=weights, reset=reset, altinput=True,
    )

    # TODO
    # - make gamma='auto' depend on maxd and reset with a dictionary, error
    #   if not specified
    # - do not require to specify splitting points if using indices

class BART(_BARTBase):

    __doc__ = _BARTBase.__doc__

    @classmethod
    def splits_from_coord(cls, x):
        """
        Generate splitting points from data.

        Parameters
        ----------
        x : array of numbers
            The data. Can be passed in two formats: 1) a structured array
            where each leaf field represents a dimension, 2) a normal array
            where the last axis runs over dimensions. In the structured case,
            each index in any shaped field is a different dimension.

        Returns
        -------
        length : int (p,) array
            The number of splitting points along each of ``p`` dimensions.
        splits : (n, p) array
            Each column contains the sorted splitting points along a dimension.
            The splitting points are the midpoints between consecutive values
            appearing in `x` for that dimension. Column ``splits[:, i]``
            contains splitting points only up to ``length[i]``, while afterward
            it is filled with a very large value.

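        Examples
        --------
        An illustrative sketch (expected values noted in comments rather than
        shown as checked output, since dtypes depend on the JAX configuration;
        assumes the public ``lsqfitgp.BART`` alias):

        >>> import numpy as np
        >>> import lsqfitgp as lgp
        >>> x = np.array([[1., 10.], [2., 20.], [2., 30.]])   # 3 points, p = 2
        >>> length, splits = lgp.BART.splits_from_coord(x)
        >>> # length == [1, 2]: one midpoint on the first axis, two on the second
        >>> # splits[:1, 0] == [1.5], splits[:2, 1] == [15., 25.], and the
        >>> # remaining entries of splits are filled with a very large value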

        """
        x = cls._check_x(x)
        return cls._splits_from_coord(x)

    # TODO options like BayesTree, i.e., use an evenly spaced range
    # instead of quantilizing, and set a maximum number of splits. Use the
    # same parameter names as BayesTree::bart, but change the defaults.

    @staticmethod
    @jax.jit
    def _splits_from_coord(x):
        """
        Jitted implementation of splits_from_coord. Applying jit avoids the
        recompilation in lax.scan each time the method is called, and
        splits_from_coord cannot be jitted directly because x could be a numpy
        structured array.
        """
        x = x.reshape(-1, x.shape[-1]) if x.size else x.reshape(1, x.shape[-1])
        if jnp.issubdtype(x.dtype, jnp.inexact):
            info = jnp.finfo
        else:
            info = jnp.iinfo
        fill = info(x.dtype).max
        def loop(_, xi):
            u = jnp.unique(xi, size=xi.size, fill_value=fill)
            m = jnp.where(u[1:] < fill, (u[1:] + u[:-1]) / 2, fill)
            l = jnp.searchsorted(m, fill)
            return _, (l, m)
        _, (length, midpoints) = lax.scan(loop, None, x.T)
        return length, midpoints.T

    @classmethod
    def indices_from_coord(cls, x, splits):
        """
        Convert coordinates to indices w.r.t. splitting points.

        Parameters
        ----------
        x : array of numbers
            The coordinates. Can be passed in two formats: 1) a structured
            array where each leaf field represents a dimension, 2) a normal
            array where the last axis runs over dimensions. In the structured
            case, each index in any shaped field is a different dimension.
        splits : pair of arrays
            The first is an int (p,) array containing the number of splitting
            points along each dimension, the second has shape (n, p) and
            contains the sorted splitting points in each column, filled with
            high values after the length.

        Returns
        -------
        ix : int array
            An array with the same shape as ``x``, unless ``x`` is a structured
            array, in which case the last axis of ``ix`` is the flattened
            version of the structured type. ``ix`` contains indices mapping
            ``x`` to positions between splitting points along each coordinate,
            with the following convention: index 0 means before the first
            split, index i > 0 means between split i - 1 and split i.

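        Examples
        --------
        A sketch continuing the `splits_from_coord` example (expected values
        in comments, assuming the public ``lsqfitgp.BART`` alias):

        >>> import numpy as np
        >>> import lsqfitgp as lgp
        >>> x = np.array([[1., 10.], [2., 20.], [2., 30.]])
        >>> splits = lgp.BART.splits_from_coord(x)
        >>> ix = lgp.BART.indices_from_coord(x, splits)
        >>> # ix == [[0, 0], [1, 1], [1, 2]]: e.g., 10. lies before the first
        >>> # split of the second axis (15.), 20. between the first and second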

        """
        splits = cls._check_splits(splits, False)
        return cls._indices_from_coord(x, splits)

    @classmethod
    def _indices_from_coord(cls, x, checked_splits):
        x = cls._check_x(x)
        if x.shape[-1] != checked_splits[0].size:
            raise ValueError(f'splitting grid is for {checked_splits[0].size} '
                             f'dimensions, found {x.shape[-1]}')
        return cls._searchsorted_vectorized(checked_splits[1], x)

    @classmethod
    def correlation(cls,
        splitsbefore_or_totalsplits,
        splitsbetween_or_index1,
        splitsafter_or_index2,
        *,
        alpha=0.95,
        beta=2,
        gamma=1,
        maxd=2,
        debug=False,
        pnt=None,
        intercept=True,
        weights=None,
        reset=None,
        altinput=False):
        """
        Compute the BART prior correlation between two points.

        Apart from the arguments ``maxd``, ``debug`` and ``reset``, this
        method is fully vectorized.

        Parameters
        ----------
        splitsbefore_or_totalsplits : int (p,) array
            The number of splitting points less than the two points, separately
            along each coordinate, or the total number of splits if
            ``altinput``.
        splitsbetween_or_index1 : int (p,) array
            The number of splitting points between the two points, separately
            along each coordinate, or the index in the splitting bins of the
            first point if ``altinput``, where 0 means to the left of the
            leftmost splitting point.
        splitsafter_or_index2 : int (p,) array
            The number of splitting points greater than the two points,
            separately along each coordinate, or the index in the splitting
            bins of the second point if ``altinput``.
        debug : bool
            If True, disable shortcuts in the tree recursion. Default False.
        altinput : bool
            If True, take as input the indices in the splitting bins of the
            points instead of the counts of splitting points separating them,
            and use a different implementation optimized for that case. Default
            False. The `BART` kernel uses ``altinput=True``.
        Other parameters :
            See `BART`.

        Returns
        -------
        corr : scalar
            The prior correlation.
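
        Examples
        --------
        A sketch with two covariates, using the default "counts" input format
        (``altinput=False``), default hyperparameters, and illustrative counts
        (assuming the public ``lsqfitgp.BART`` alias):

        >>> import lsqfitgp as lgp
        >>> corr = lgp.BART.correlation(
        ...     [1, 0],    # splitting points below both points, per covariate
        ...     [2, 1],    # splitting points between the two points
        ...     [0, 3],    # splitting points above both points
        ... )
        >>> # corr is a scalar correlation in [1 - alpha, 1] = [0.05, 1]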

        """

        # check splitting indices are integers
        splitsbefore_or_totalsplits = jnp.asarray(splitsbefore_or_totalsplits)
        splitsbetween_or_index1 = jnp.asarray(splitsbetween_or_index1)
        splitsafter_or_index2 = jnp.asarray(splitsafter_or_index2)
        assert jnp.issubdtype(splitsbefore_or_totalsplits.dtype, jnp.integer)
        assert jnp.issubdtype(splitsbetween_or_index1.dtype, jnp.integer)
        assert jnp.issubdtype(splitsafter_or_index2.dtype, jnp.integer)

        # check splitting indices
        with _jaxext.skipifabstract():
            assert jnp.all(splitsbefore_or_totalsplits >= 0), 'splitting counts must be nonnegative'
            if altinput:
                assert jnp.all((0 <= splitsbetween_or_index1) & (splitsbetween_or_index1 <= splitsbefore_or_totalsplits)), 'splitting index must be in [0, n]'
                assert jnp.all((0 <= splitsafter_or_index2) & (splitsafter_or_index2 <= splitsbefore_or_totalsplits)), 'splitting index must be in [0, n]'
            else:
                assert jnp.all(splitsbetween_or_index1 >= 0), 'splitting counts must be nonnegative'
                assert jnp.all(splitsafter_or_index2 >= 0), 'splitting counts must be nonnegative'

        # get splitting probabilities
        if pnt is None:
            assert maxd == int(maxd) and maxd >= 0, maxd
            alpha = jnp.asarray(alpha)
            beta = jnp.asarray(beta)
            with _jaxext.skipifabstract():
                assert jnp.all((0 <= alpha) & (alpha <= 1)), 'alpha must be in [0, 1]'
                assert jnp.all(beta >= 0), 'beta must be in [0, inf)'
            d = jnp.arange(maxd + 1)
            alpha = alpha[..., None]
            beta = beta[..., None]
            pnt = alpha / (1 + d) ** beta
        else:
            pnt = jnp.asarray(pnt)

        # get covariate weights
        if weights is None:
            weights = jnp.ones(splitsbefore_or_totalsplits.shape[-1], pnt.dtype)
        else:
            weights = jnp.asarray(weights)

        # get interpolation coefficients
        if isinstance(gamma, str):
            if gamma == 'auto':
                assert reset is None and 1 <= pnt.shape[-1] - 1 <= 3
                p = weights.shape[-1]
                gamma = cls._gamma(p, pnt)
            else:
                raise KeyError(gamma)
        else:
            gamma = jnp.asarray(gamma)

        # check values are in range
        with _jaxext.skipifabstract():
            assert jnp.all((0 <= gamma) & (gamma <= 1)), 'gamma must be in [0, 1]'
            assert jnp.all((0 <= pnt) & (pnt <= 1)), 'pnt must be in [0, 1]'
            assert jnp.all(weights >= 0), 'weights must be in [0, inf)'

        # set first splitting probability to 1 to remove flat baseline (keep
        # last!)
        if not intercept:
            pnt = pnt.at[..., 0].set(1)

        # expand and check recursion reset depths
        if reset is None:
            reset = []
        if not hasattr(reset, '__len__'):
            reset = [reset]
        reset = [0] + list(reset) + [pnt.shape[-1] - 1]
        for i, j in zip(reset, reset[1:]):
            assert int(j) == j and i <= j, (i, j)

        # convert reset depths list to brackets with repetition
        brackets_norep = list(zip(reset, reset[1:]))
        brackets = [brackets_norep[0] + (1,)]
        for t, b in brackets_norep[1:]:
            lt, lb, lr = brackets[-1]
            if altinput and not debug and lr * (b - t) == lb - lt and b - t <= 2:
                brackets[-1] = lt, b, lr + 1
            else:
                brackets.append((t, b, 1))

        # call recursive function for each recursion slice
        corr = gamma
        for t, b, repeat in reversed(brackets):
            probs = pnt[..., t:b + 1]
            if t > 0:
                probs = probs.at[..., 0].set(1)
            if repeat > 1:
                head = probs[..., 0:1]
                one = jnp.ones_like(head)
                probs = jnp.concatenate(sum(reversed([
                    [head if i == 0 else one, p]
                    for i, p in enumerate(jnp.split(probs[..., 1:], repeat, axis=-1))
                ]), start=[]), axis=-1)
            else:
                repeat = None
            corr = cls._correlation_vectorized(
                splitsbefore_or_totalsplits,
                splitsbetween_or_index1,
                splitsafter_or_index2,
                probs, corr, weights,
                debug, altinput, repeat,
            )
        return corr

    # TODO public method to compute pnt

    @staticmethod
    def _gamma(p, pnt):
        # gamma(alpha, beta, maxd) =
        #   = (gamma_0 - gamma_d maxd) (1 - alpha^s 2^(-t beta)) =
        #   = (gamma_0 - gamma_d maxd) (1 - P0^s-t P1^t)

        gamma_0 = 0.611 + 0.021 * jnp.exp(-1.3 * (p - 1))
        gamma_d = -0.0034 + 0.084 * jnp.exp(-2.02 * (p - 1))
        s = 2.03 - 0.69 * jnp.exp(-0.72 * (p - 1))
        t = 4.01 - 1.49 * jnp.exp(-0.77 * (p - 1))

        maxd = pnt.shape[-1] - 1
        floor = jnp.clip(gamma_0 - gamma_d * maxd, 0, 1)

        P0 = pnt[..., 0]
        P1 = jnp.minimum(P0, pnt[..., 1])
        corner = jnp.where(P0, 1 - P0 ** (s - t) * P1 ** t, 1)

        return floor * corner

    # TODO make this public?

    @staticmethod
    def _check_x(x):
        x = _array.asarray(x)
        if x.dtype.names:
            x = recfunctions.structured_to_unstructured(x)
        return x

    @staticmethod
    def _check_splits(splits, indices):
        l, s = splits
        l = jnp.asarray(l)
        assert l.ndim == 1
        if not indices:
            s = jnp.asarray(s)
            assert 1 <= s.ndim <= 2
            if s.ndim == 1:
                s = s[:, None]
            assert l.size == s.shape[1]
        with _jaxext.skipifabstract():
            assert jnp.all((0 <= l) & (l <= s.shape[0])), 'length out of bounds'
            if not indices:
                assert jnp.all(jnp.sort(s, axis=0) == s), 'unsorted splitting points'
        return l, s

    @staticmethod
    @functools.partial(jax.jit, static_argnames=('side',))
    def _searchsorted_vectorized(A, V, **kw):
        """
        A : (n, p)
        V : (..., p)
        out : (..., p)
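
        A sketch of the intended behavior (illustrative values; equivalent to
        applying ``jnp.searchsorted`` column by column):

        >>> from jax import numpy as jnp
        >>> import lsqfitgp as lgp
        >>> A = jnp.array([[1., 10.], [2., 20.], [3., 30.]])
        >>> V = jnp.array([[1.5, 25.]])
        >>> out = lgp.BART._searchsorted_vectorized(A, V)
        >>> # out == [[1, 2]]: 1.5 sorts at position 1 in the first column of A,
        >>> # 25. at position 2 in the second column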

        """
        def loop(_, av):
            return _, jnp.searchsorted(*av, **kw)
        _, out = lax.scan(loop, None, (A.T, V.T))
        return out.T

    @classmethod
    @functools.partial(jax.jit, static_argnums=(0, 7))
    def _correlation_old(cls, nminus, n0, nplus, pnt, gamma, w, debug):
        """ old version, kept around for cross-checking """

        assert nminus.shape == n0.shape == nplus.shape == w.shape
        assert nminus.ndim == 1 and nminus.size >= 0
        assert pnt.ndim == 1 and pnt.size > 0
        # TODO repeat these shape checks in BART.correlation such that the
        # error messages are user-legible

        # optimization to avoid looping over ignored axes
        nminus = jnp.where(w, nminus, 0)
        n0 = jnp.where(w, n0, 0)
        nplus = jnp.where(w, nplus, 0)

        float_type = _jaxext.float_type(pnt, gamma, w)

        if nminus.size == 0:
            return jnp.array(1, float_type)

        anyn0 = jnp.any(jnp.logical_and(n0, w))

        if pnt.size == 1:
            return jnp.where(anyn0, 1 - (1 - gamma) * pnt[0], 1)

        nout = nminus + nplus
        n = nout + n0
        Wn = jnp.sum(jnp.where(n, w, 0)) # <-- @

        if pnt.size == 2 and not debug:
            Q = 1 - (1 - gamma) * pnt[1]
            sump = Q * jnp.sum(jnp.where(n, w * nout / n, 0)) # <-- @
            return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

        if pnt.size == 3 and not debug:
            Q = 1 - (1 - gamma) * pnt[2]
            s = w * nout / n
            S = jnp.sum(jnp.where(n, s, 0)) # <-- @
            t = w * n0 / n
            psin = jspecial.digamma(n.astype(float_type))
            def terms(nminus, nplus):
                nminus0 = nminus + n0
                Wnmod = Wn - jnp.where(nminus0, 0, w)
                frac = jnp.where(nminus0, w * nminus / nminus0, 0)
                terms1 = (S - s + frac) / Wnmod
                psi1nminus0 = jspecial.digamma((1 + nminus0).astype(float_type))
                terms2 = ((nplus - 1) * (S + t) - w * n0 * (psin - psi1nminus0)) / Wn
                return jnp.where(nplus, terms1 + terms2, 0)
            tplus = terms(nminus, nplus)
            tminus = terms(nplus, nminus)
            tall = jnp.where(n, w * (tplus + tminus) / n, 0)
            sump = (1 - pnt[1]) * S + pnt[1] * Q * jnp.sum(tall) # <-- @
            return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

        # TODO the pnt.size == 3 calculation is probably less accurate than
        # the recursive one, see comparison limits > 30 ULP in test_bart.py

        p = len(nminus)

        val = (0., nminus, n0, nplus)
        def loop(i, val):
            sump, nminus, n0, nplus = val

            nminusi = nminus[i]
            n0i = n0[i]
            nplusi = nplus[i]
            ni = nminusi + n0i + nplusi

            val = (0., nminus, n0, nplus, i, nminusi)
            def loop(k, val):
                sumn, nminus, n0, nplus, i, nminusi = val

                # here I use the fact that .at[].set won't set the value if the
                # index is out of bounds
                nminus = nminus.at[jnp.where(k < nminusi, i, i + p)].set(k)
                nplus = nplus.at[jnp.where(k >= nminusi, i, i + p)].set(k - nminusi)

                sumn += cls._correlation_old(nminus, n0, nplus, pnt[1:], gamma, w, debug)

                nminus = nminus.at[i].set(nminusi)
                nplus = nplus.at[i].set(nplusi)

                return sumn, nminus, n0, nplus, i, nminusi

            # if ni == 0 I skip recursion by passing 0 as iteration end
            end = jnp.where(ni, nminusi + nplusi, 0)
            start = jnp.zeros_like(end)
            sumn, nminus, n0, nplus, _, _ = lax.fori_loop(start, end, loop, val)

            sump += jnp.where(ni, w[i] * sumn / ni, 0)

            return sump, nminus, n0, nplus

        # skip summation if all(n0 == 0)
        end = jnp.where(anyn0, p, 0)
        sump, _, _, _ = lax.fori_loop(0, end, loop, val)

        return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

    @staticmethod
    def _scan_but_first(f, init, xs):
        """ lax.scan, but execute the first cycle separately. The point is that
        I use it when the first cycle works on smaller arrays due to
        broadcasting. """
        assert isinstance(xs, jnp.ndarray)
        assert len(xs) > 0
        init, out = f(init, xs[0])
        assert out is None
        if len(xs) == 1:
            return init, out
        elif len(xs) == 2:
            return f(init, xs[1])
        else:
            return lax.scan(f, init, xs[1:])

    @classmethod
    @functools.partial(jax.jit, static_argnums=(0, 7, 8))
    def _correlation(cls, n, ix, iy, pnt, gamma, w, debug, repeat):
        # this implementation is optimized assuming that the shapes are as
        # follows:
        #   n      (p,)
        #   ix     (n, 1, p)
        #   iy     (1, n, p)
        #   pnt    (d,)
        #   gamma  () or (n, n)
        #   w      (p,)

        assert n.ndim == 1
        assert n.shape == ix.shape == iy.shape == w.shape
        assert pnt.ndim == 1 and pnt.size > 0
        assert gamma.ndim == 0
        # TODO repeat these shape checks in BART.correlation such that the
        # error messages are user-legible

        # check the strict conditions under which `repeat` is implemented
        if repeat is not None:
            assert (
                not debug
                and repeat > 0
                and pnt.size % repeat == 0
                and pnt.size // repeat <= 3
            )
        else:
            repeat = 1

        # infer float type from float arguments
        flt = _jaxext.float_type(pnt, gamma, w)

        # no covariates, always return 1
        if n.size == 0:
            return jnp.array(1, flt)

        # pre-cast all floats to the common type, to avoid unwanted float32
        # calculations in mixed float-integer operations
        pnt = pnt.astype(flt)
        gamma = gamma.astype(flt)
        w = w.astype(flt)

        # ignore zero-weight axes
        n = jnp.where(w, n, 0)
        ix = jnp.where(w, ix, 0)
        iy = jnp.where(w, iy, 0)

        # check if the points coincide
        seed = jnp.uint64(16132933535611723338)
        hx = _jaxext.fasthash64(ix, seed)
        hy = _jaxext.fasthash64(iy, seed)
        anyn0 = hx != hy
        # no hash collision checking, it would be branchless because of vmap,
        # the probability of collision building a nxn matrix with n=10000 is
        # -expm1(10000**2 * log1p(-1/2**64)) = 5e-12.

        # base case of the recursion, no dependence on points apart from the
        # case when they are equal
        if pnt.size // repeat == 1:
            def loop(carry, pnt):
                anyn0, gamma = carry
                gamma = jnp.where(anyn0, 1 - (1 - gamma) * pnt[0], 1)
                return (anyn0, gamma), None
            (_, gamma), _ = cls._scan_but_first(loop, (anyn0, gamma), pnt.reshape(repeat, -1))
            return gamma

        # normalization for axes weights
        Wn = jnp.sum(jnp.where(n, w, 0))

        # shortcut for the last two levels of the recursion
        if pnt.size // repeat == 2 and not debug:
            n0 = jnp.abs(ix - iy)
            sum_term = jnp.where(n, w / n, 0) @ n0
            def loop(carry, pnt):
                anyn0, Wn, sum_term, gamma = carry
                Q = 1 - pnt[1] + gamma * pnt[1]
                P0 = pnt[0]
                result = 1 - P0 + Q * (P0 - P0 / Wn * sum_term)
                gamma = jnp.where(anyn0, result, 1)
                return (anyn0, Wn, sum_term, gamma), None
            (_, _, _, gamma), _ = cls._scan_but_first(loop, (anyn0, Wn, sum_term, gamma), pnt.reshape(repeat, -1))
            return gamma

        # convert to alternative format
        xlty = ix < iy
        minxy = jnp.where(xlty, ix, iy)
        maxxy = jnp.where(xlty, iy, ix)
        n0 = maxxy - minxy

        # shortcut for the last three levels of the recursion
        if pnt.size // repeat == 3 and not debug:
            nminus0 = maxxy
            nplus0 = n - minxy
            nout = n - n0

            inv_Wn = 1 / Wn
            inv_Wnmod = 1 / (Wn - jnp.where(n, w, 0))
            inv_Wnminus = jnp.where(nplus0, inv_Wn, inv_Wnmod)
            inv_Wnplus = jnp.where(nminus0, inv_Wn, inv_Wnmod)
            wn = jnp.where(n, w / n, 0)
            S = wn @ nout

            t = wn * n0
            terms1 = (S + t) * (inv_Wnminus + inv_Wnplus + inv_Wn * (nout - 2))

            terms2 = jnp.where(nplus0, w * inv_Wn * n0 / nplus0, w * inv_Wnmod)
            terms2 += jnp.where(nminus0, w * inv_Wn * n0 / nminus0, w * inv_Wnmod)

            psin = jspecial.digamma(jnp.where(n, n, 1).astype(flt))
            psiminus = jnp.where(xlty,
                jspecial.digamma((1 + iy).astype(flt)),
                jspecial.digamma((1 + ix).astype(flt)),
            )
            psiplus = jnp.where(xlty,
                jspecial.digamma((1 + n - ix).astype(flt)),
                jspecial.digamma((1 + n - iy).astype(flt)),
            )
            terms3 = w * inv_Wn * n0 * (2 * psin - psiminus - psiplus)

            terms = terms1 - terms2 - terms3
            sumi = wn @ terms

            def loop(carry, pnt):
                anyn0, inv_Wn, S, sumi, gamma = carry
                Q = 1 + pnt[2] * (gamma - 1)
                sump = S + pnt[1] * (Q * sumi - S)
                result = 1 + pnt[0] * (inv_Wn * sump - 1)
                gamma = jnp.where(anyn0, result, 1)
                return (anyn0, inv_Wn, S, sumi, gamma), None
            (_, _, _, _, gamma), _ = cls._scan_but_first(loop, (anyn0, inv_Wn, S, sumi, gamma), pnt.reshape(repeat, -1))
            return gamma

        # finish conversion to alternative format
        nminus = minxy
        nplus = n - maxxy
        p = len(nminus)
        del ix, iy, maxxy, minxy

        val = (0., nminus, n0, nplus)
        def loop(i, val):
            sump, nminus, n0, nplus = val

            nminusi = nminus[i]
            n0i = n0[i]
            nplusi = nplus[i]
            ni = nminusi + n0i + nplusi

            val = (0., nminus, n0, nplus, i, nminusi)
            def loop(k, val):
                sumn, nminus, n0, nplus, i, nminusi = val

                # here I use the fact that .at[].set won't set the value if the
                # index is out of bounds
                nminus = nminus.at[jnp.where(k < nminusi, i, i + p)].set(k)
                nplus = nplus.at[jnp.where(k >= nminusi, i, i + p)].set(k - nminusi)

                n = nminus + n0 + nplus
                ix = nminus
                iy = nminus + n0
                sumn += cls._correlation(n, ix, iy, pnt[1:], gamma, w, debug, None)

                nminus = nminus.at[i].set(nminusi)
                nplus = nplus.at[i].set(nplusi)

                return sumn, nminus, n0, nplus, i, nminusi

            # if ni == 0 I skip recursion by passing 0 as iteration end
            end = jnp.where(ni, nminusi + nplusi, 0)
            start = jnp.zeros_like(end)
            sumn, nminus, n0, nplus, _, _ = lax.fori_loop(start, end, loop, val)

            sump += jnp.where(ni, w[i] * sumn / ni, 0)

            return sump, nminus, n0, nplus

        # skip summation if all(n0 == 0)
        end = jnp.where(anyn0, p, 0)
        sump, _, _, _ = lax.fori_loop(0, end, loop, val)

        return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

    @classmethod
    @functools.partial(jnp.vectorize, excluded=(0, 7, 8, 9), signature='(p),(p),(p),(d),(),(p)->()')
    def _correlation_vectorized(cls, nminus_or_n, n0_or_ix, nplus_or_iy, pnt, gamma, w, debug, altinput, repeat):
        if altinput:
            func = lambda *args: cls._correlation(*args, repeat)
        else:
            func = cls._correlation_old
        return func(nminus_or_n, n0_or_ix, nplus_or_iy, pnt, gamma, w, bool(debug))