Coverage for src/lsqfitgp/_kernels/_bart.py: 99%
354 statements
coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
# lsqfitgp/_kernels/_bart.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import functools

import jax
from jax import numpy as jnp
from jax import lax
from jax.scipy import special as jspecial
from numpy.lib import recfunctions

from .. import _jaxext
from .. import _array
from .._Kernel import kernel

@kernel(derivable=False, batchbytes=10e6)
# TODO maybe batching should be done automatically by GP instead of by the
# kernels? But before doing that I need to support batching non-traceable
# functions.
def _BARTBase(x, y,
    alpha=0.95,
    beta=2,
    maxd=2,
    gamma=1,
    splits=None,
    pnt=None,
    intercept=True,
    weights=None,
    reset=None,
    indices=False):
    """
    BART kernel.

    Good default parameters: ``maxd=4, reset=2`` if ``alpha`` and ``beta`` are
    kept fixed at the default values, ``maxd=10, reset=[2,4,6,8]`` otherwise.
    Derivatives are faster with forward autodiff.

    Parameters
    ----------
    x, y : arrays
        Input points. The array type can be structured, in which case every leaf
        field represents a dimension; or unstructured, which specifies a single
        dimension.
    alpha, beta : scalar
        The parameters of the branching probability.
    maxd : int
        The maximum depth of the trees.
    splits : pair of arrays
        The first is an int (p,) array containing the number of splitting
        points along each dimension, the second has shape (n, p) and contains
        the sorted splitting points in each column, filled with high values
        after the length. Use `BART.splits_from_coord` to produce them.
    gamma : scalar or str
        Interpolation coefficient in [0, 1] between a lower and an upper
        bound on the infinite maxd limit, or a string 'auto' indicating to
        use a formula which depends on alpha, beta, maxd and the number of
        covariates, empirically calibrated on maxd from 1 to 3. Default 1
        (upper bound).
    pnt : (maxd + 1,) array, optional
        Nontermination probabilities at depths 0...maxd. If specified,
        ``alpha``, ``beta`` and ``maxd`` are ignored.
    intercept : bool, default True
        The correlation is in [1 - alpha, 1] (or [1 - pnt[0], 1] when using
        pnt). If intercept=False, it is rescaled to [0, 1].
    weights : (p,) array, optional
        Unnormalized selection probabilities for the covariate axes. If not
        specified, all axes have the same probability of being selected for
        splitting.
    reset : int or sequence of int, optional
        List of depths at which the recursion is reset, in the sense that the
        function value at a reset depth is evaluated on the initial inputs for
        all recursion paths, instead of the modified input handed down by the
        recursion. Default none.
    indices : bool, default False
        If False, the inputs `x`, `y` represent coordinate values. If True,
        they are taken to be already the indices of the points in the splitting
        grid, as can be obtained with `BART.indices_from_coord`.

    Methods
    -------
    splits_from_coord
    indices_from_coord
    correlation

    Notes
    -----
    This is the covariance function of the latent mean prior of BART (Bayesian
    Additive Regression Trees) [1]_ with an upper bound :math:`D` on the depth
    of the trees. This prior is the distribution of the function

    .. math::
        f(\\mathbf x) = \\lim_{m\\to\\infty}
            \\sum_{j=1}^m g(\\mathbf x; T_j, M_j),

    where each :math:`g(\\mathbf x; T_j, M_j)` is a decision tree evaluated at
    :math:`\\mathbf x`, with structure :math:`T_j` and leaf values :math:`M_j`.
    The trees are i.i.d., with the following distribution for :math:`T_j`: for
    a node at depth :math:`d`, with :math:`d = 0` for the root, the probability
    of not being a leaf, conditional on its existence and its ancestors only, is

    .. math::
        P_d = \\alpha (1+d)^{-\\beta}, \\quad
        \\alpha \\in [0, 1], \\quad \\beta \\ge 0.

    For a non-leaf node, conditional on existence and ancestors, the splitting
    variable has uniform distribution amongst the variables with any splitting
    points not used by ancestors, and the splitting point has uniform
    distribution amongst the available ones. The splitting points are fixed,
    typically from the data.

    The distribution of leaves :math:`M_j` is i.i.d. Normal with variance
    :math:`1/m`, such that :math:`f(x)` has variance 1. In the limit
    :math:`m\\to\\infty`, the distribution of :math:`f(x)` becomes a Gaussian
    process.

    Since the trees are independent, the covariance function can be computed
    for a single tree. Consider two coordinates :math:`x` and :math:`y`, with
    :math:`x \\le y`. Let :math:`n^-`, :math:`n^0` and :math:`n^+` be the
    number of splitting points respectively before :math:`x`, between
    :math:`x`, :math:`y` and after :math:`y`. Next, define :math:`\\mathbf
    n^-`, :math:`\\mathbf n^0` and :math:`\\mathbf n^+` as the vectors of such
    quantities for each dimension, with a total of :math:`p` dimensions, and
    :math:`\\mathbf n = \\mathbf n^- + \\mathbf n^0 + \\mathbf n^+`. Then the
    covariance function can be written recursively as

    .. math::
        \\newcommand{\\nvecs}{\\mathbf n^-, \\mathbf n^0, \\mathbf n^+}
        k(\\mathbf x, \\mathbf y) &= k_0(\\nvecs), \\\\
        k_D(\\nvecs) &= 1 - (1 - \\gamma) P_D,
            \\quad \\mathbf n^0 \\ne \\mathbf 0, \\\\
        k_d(\\mathbf 0, \\mathbf 0, \\mathbf 0) &= 1, \\\\
        k_d(\\nvecs) &= 1 - P_d \\Bigg(1 - \\frac1{W(\\mathbf n)}
            \\sum_{\\substack{i=1 \\\\ n_i\\ne 0}}^p
            \\frac{w_i}{n_i} \\Bigg( \\\\
        &\\qquad \\sum_{k=0}^{n^-_i - 1}
            k_{d+1}(\\mathbf n^-_{n^-_i=k}, \\mathbf n^0, \\mathbf n^+)
            + {} \\\\
        &\\qquad \\sum_{k=0}^{n^+_i - 1}
            k_{d+1}(\\mathbf n^-, \\mathbf n^0, \\mathbf n^+_{n^+_i=k})
            \\Bigg)
        \\Bigg), \\quad d < D, \\\\
        W(\\mathbf n) &= \\sum_{\\substack{i=1 \\\\ n_i\\ne 0}}^p w_i.

    The introduction of a maximum depth :math:`D` is necessary for
    computational feasibility. As :math:`D` increases, the result converges to
    the one without depth limit. For :math:`D \\le 2` (the default value), the
    covariance is implemented in closed form and takes :math:`O(p)` to compute.
    For :math:`D > 2`, the computational complexity grows exponentially as
    :math:`O(p(\\bar np)^{D-2})`, where :math:`\\bar n` is the average number of
    splitting points along a dimension.

    If the maximum allowed depth is 1, i.e., either :math:`D = 1` or
    :math:`\\beta\\to\\infty`, the kernel assumes the simple form

    .. math::
        k(\\mathbf x, \\mathbf y) &= 1 - P_0 \\left(
            1 - Q + \\frac Q{W(\\mathbf n)}
            \\sum_{\\substack{i=1 \\\\ n_i\\ne 0}}^p w_i
            \\frac{n^0_i}{n_i} \\right), \\\\
        Q &= \\begin{cases}
            1 - (1 - \\gamma) P_1 & \\mathbf n^0 \\ne \\mathbf 0, \\\\
            1 & \\mathbf n^0 = \\mathbf 0,
        \\end{cases}

    which is separable along dimensions, i.e., it has no interactions.

    References
    ----------
    .. [1] Hugh A. Chipman, Edward I. George, Robert E. McCulloch, "BART:
        Bayesian additive regression trees," The Annals of Applied Statistics,
        Ann. Appl. Stat. 4(1), 266-298, (March 2010).
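
    Examples
    --------
    A minimal one-dimensional sketch, assuming the kernel is exposed at the
    package top level as ``lsqfitgp.BART`` and called with the usual
    two-argument kernel convention:

    >>> import numpy as np
    >>> import lsqfitgp as lgp
    >>> x = np.linspace(0, 1, 20)                    # 20 points, 1 covariate
    >>> splits = lgp.BART.splits_from_coord(x[:, None])
    >>> kernel = lgp.BART(splits=splits, maxd=4, reset=2)
    >>> cov = kernel(x[:, None], x[None, :])         # (20, 20) prior covariance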
    """

    splits = BART._check_splits(splits, indices)
    if not x.dtype.names:
        x = x[..., None]
    if not y.dtype.names:
        y = y[..., None]
    if indices:
        ix = BART._check_x(x)
        iy = BART._check_x(y)
    else:
        ix = BART._indices_from_coord(x, splits)
        iy = BART._indices_from_coord(y, splits)
    return BART.correlation(
        splits[0], ix, iy,
        pnt=pnt, alpha=alpha, beta=beta, gamma=gamma, maxd=maxd,
        intercept=intercept, weights=weights, reset=reset, altinput=True,
    )

    # TODO
    # - make gamma='auto' depend on maxd and reset with a dictionary, error
    #   if not specified
    # - do not require to specify splitting points if using indices

class BART(_BARTBase):

    __doc__ = _BARTBase.__doc__

    @classmethod
    def splits_from_coord(cls, x):
        """
        Generate splitting points from data.

        Parameters
        ----------
        x : array of numbers
            The data. Can be passed in two formats: 1) a structured array where
            each leaf field represents a dimension, 2) a normal array where the
            last axis runs over dimensions. In the structured case, each
            index in any shaped field is a different dimension.

        Returns
        -------
        length : int (p,) array
            The number of splitting points along each of ``p`` dimensions.
        splits : (n, p) array
            Each column contains the sorted splitting points along a dimension.
            The splitting points are the midpoints between consecutive values
            appearing in `x` for that dimension. Column ``splits[:, i]``
            contains splitting points only up to ``length[i]``, while afterward
            it is filled with a very large value.
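
        Examples
        --------
        A small hand-checked sketch, assuming the class is importable as
        ``from lsqfitgp import BART``:

        >>> import numpy as np
        >>> from lsqfitgp import BART
        >>> x = np.array([[1., 10.],
        ...               [2., 20.],
        ...               [2., 20.],
        ...               [3., 30.]])     # 4 points, 2 dimensions, one repeated
        >>> length, splits = BART.splits_from_coord(x)
        >>> # length == [2, 2]: two splitting points per dimension
        >>> # splits[:2] == [[1.5, 15.], [2.5, 25.]], the midpoints between
        >>> # consecutive distinct values; splits[2] is filled with a very
        >>> # large value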
        """
        x = cls._check_x(x)
        return cls._splits_from_coord(x)

    # TODO options like BayesTree, i.e., use an evenly spaced range
    # instead of quantilizing, and set a maximum number of splits. Use the
    # same parameter names as BayesTree::bart, but change the defaults.

    @staticmethod
    @jax.jit
    def _splits_from_coord(x):
        """
        Jitted implementation of splits_from_coord. Applying jit avoids the
        recompilation in lax.scan each time the method is called, and
        splits_from_coord can not be jitted directly because x could be a numpy
        structured array.
        """
        x = x.reshape(-1, x.shape[-1]) if x.size else x.reshape(1, x.shape[-1])
        if jnp.issubdtype(x.dtype, jnp.inexact):
            info = jnp.finfo
        else:
            info = jnp.iinfo
        fill = info(x.dtype).max
        def loop(_, xi):
            u = jnp.unique(xi, size=xi.size, fill_value=fill)
            m = jnp.where(u[1:] < fill, (u[1:] + u[:-1]) / 2, fill)
            l = jnp.searchsorted(m, fill)
            return _, (l, m)
        _, (length, midpoints) = lax.scan(loop, None, x.T)
        return length, midpoints.T

    @classmethod
    def indices_from_coord(cls, x, splits):
        """
        Convert coordinates to indices w.r.t. splitting points.

        Parameters
        ----------
        x : array of numbers
            The coordinates. Can be passed in two formats: 1) a structured
            array where each leaf field represents a dimension, 2) a normal
            array where the last axis runs over dimensions. In the structured
            case, each index in any shaped field is a different dimension.
        splits : pair of arrays
            The first is an int (p,) array containing the number of splitting
            points along each dimension, the second has shape (n, p) and
            contains the sorted splitting points in each column, filled with
            high values after the length.

        Returns
        -------
        ix : int array
            An array with the same shape as ``x``, unless ``x`` is a structured
            array, in which case the last axis of ``ix`` is the flattened version
            of the structured type. ``ix`` contains indices mapping ``x`` to
            positions between splitting points along each coordinate, with the
            following convention: index 0 means before the first split, index
            i > 0 means between split i - 1 and split i.
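
        Examples
        --------
        A hand-checked sketch with one covariate and splitting points at 1.5
        and 2.5, assuming the class is importable as ``from lsqfitgp import
        BART``:

        >>> import numpy as np
        >>> from lsqfitgp import BART
        >>> length = np.array([2])
        >>> grid = np.array([[1.5], [2.5]])
        >>> x = np.array([[1.], [2.], [3.]])
        >>> ix = BART.indices_from_coord(x, (length, grid))
        >>> # ix == [[0], [1], [2]]: 1. falls before the first split, 2. between
        >>> # the two splits, 3. after the last split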
        """
        splits = cls._check_splits(splits, False)
        return cls._indices_from_coord(x, splits)

    @classmethod
    def _indices_from_coord(cls, x, checked_splits):
        x = cls._check_x(x)
        if x.shape[-1] != checked_splits[0].size:
            raise ValueError(f'splitting grid is for {checked_splits[0].size} '
                             f'dimensions, found {x.shape[-1]}')
        return cls._searchsorted_vectorized(checked_splits[1], x)

    @classmethod
    def correlation(cls,
        splitsbefore_or_totalsplits,
        splitsbetween_or_index1,
        splitsafter_or_index2,
        *,
        alpha=0.95,
        beta=2,
        gamma=1,
        maxd=2,
        debug=False,
        pnt=None,
        intercept=True,
        weights=None,
        reset=None,
        altinput=False):
        """
        Compute the BART prior correlation between two points.

        Apart from arguments ``maxd``, ``debug`` and ``reset``, this method is
        fully vectorized.

        Parameters
        ----------
        splitsbefore_or_totalsplits : int (p,) array
            The number of splitting points less than the two points, separately
            along each coordinate, or the total number of splits if ``altinput``.
        splitsbetween_or_index1 : int (p,) array
            The number of splitting points between the two points, separately
            along each coordinate, or the index in the splitting bins of the
            first point if ``altinput``, where 0 means to the left of the leftmost
            splitting point.
        splitsafter_or_index2 : int (p,) array
            The number of splitting points greater than the two points,
            separately along each coordinate, or the index in the splitting bins
            of the second point if ``altinput``.
        debug : bool
            If True, disable shortcuts in the tree recursion. Default False.
        altinput : bool
            If True, take as input the indices in the splitting bins of the
            points instead of the counts of splitting points separating them,
            and use a different implementation optimized for that case. Default
            False. The `BART` kernel uses ``altinput=True``.
        Other parameters :
            See `BART`.

        Returns
        -------
        corr : scalar
            The prior correlation.
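
        Examples
        --------
        A minimal sketch in the default (non-``altinput``) format, with two
        covariates and splitting counts written out by hand, assuming the class
        is importable as ``from lsqfitgp import BART``:

        >>> from lsqfitgp import BART
        >>> corr = BART.correlation(
        ...     [2, 0],    # splits to the left of both points, per covariate
        ...     [1, 3],    # splits between the two points
        ...     [0, 2],    # splits to the right of both points
        ...     alpha=0.95, beta=2, maxd=2,
        ... )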
        """

        # check splitting indices are integers
        splitsbefore_or_totalsplits = jnp.asarray(splitsbefore_or_totalsplits)
        splitsbetween_or_index1 = jnp.asarray(splitsbetween_or_index1)
        splitsafter_or_index2 = jnp.asarray(splitsafter_or_index2)
        assert jnp.issubdtype(splitsbefore_or_totalsplits.dtype, jnp.integer)
        assert jnp.issubdtype(splitsbetween_or_index1.dtype, jnp.integer)
        assert jnp.issubdtype(splitsafter_or_index2.dtype, jnp.integer)

        # check splitting indices
        with _jaxext.skipifabstract():
            assert jnp.all(splitsbefore_or_totalsplits >= 0), 'splitting counts must be nonnegative'
            if altinput:
                assert jnp.all((0 <= splitsbetween_or_index1) & (splitsbetween_or_index1 <= splitsbefore_or_totalsplits)), 'splitting index must be in [0, n]'
                assert jnp.all((0 <= splitsafter_or_index2) & (splitsafter_or_index2 <= splitsbefore_or_totalsplits)), 'splitting index must be in [0, n]'
            else:
                assert jnp.all(splitsbetween_or_index1 >= 0), 'splitting counts must be nonnegative'
                assert jnp.all(splitsafter_or_index2 >= 0), 'splitting counts must be nonnegative'

        # get splitting probabilities
        if pnt is None:
            assert maxd == int(maxd) and maxd >= 0, maxd
            alpha = jnp.asarray(alpha)
            beta = jnp.asarray(beta)
            with _jaxext.skipifabstract():
                assert jnp.all((0 <= alpha) & (alpha <= 1)), 'alpha must be in [0, 1]'
                assert jnp.all(beta >= 0), 'beta must be in [0, inf)'
            d = jnp.arange(maxd + 1)
            alpha = alpha[..., None]
            beta = beta[..., None]
            pnt = alpha / (1 + d) ** beta
        else:
            pnt = jnp.asarray(pnt)

        # get covariate weights
        if weights is None:
            weights = jnp.ones(splitsbefore_or_totalsplits.shape[-1], pnt.dtype)
        else:
            weights = jnp.asarray(weights)

        # get interpolation coefficients
        if isinstance(gamma, str):
            if gamma == 'auto':
                assert reset is None and 1 <= pnt.shape[-1] - 1 <= 3
                p = weights.shape[-1]
                gamma = cls._gamma(p, pnt)
            else:
                raise KeyError(gamma)
        else:
            gamma = jnp.asarray(gamma)

        # check values are in range
        with _jaxext.skipifabstract():
            assert jnp.all((0 <= gamma) & (gamma <= 1)), 'gamma must be in [0, 1]'
            assert jnp.all((0 <= pnt) & (pnt <= 1)), 'pnt must be in [0, 1]'
            assert jnp.all(weights >= 0), 'weights must be in [0, inf)'

        # set first splitting probability to 1 to remove flat baseline (keep
        # last!)
        if not intercept:
            pnt = pnt.at[..., 0].set(1)

        # expand and check recursion reset depths
        if reset is None:
            reset = []
        if not hasattr(reset, '__len__'):
            reset = [reset]
        reset = [0] + list(reset) + [pnt.shape[-1] - 1]
        for i, j in zip(reset, reset[1:]):
            assert int(j) == j and i <= j, (i, j)

        # convert reset depths list to brackets with repetition
        brackets_norep = list(zip(reset, reset[1:]))
        brackets = [brackets_norep[0] + (1,)]
        for t, b in brackets_norep[1:]:
            lt, lb, lr = brackets[-1]
            if altinput and not debug and lr * (b - t) == lb - lt and b - t <= 2:
                brackets[-1] = lt, b, lr + 1
            else:
                brackets.append((t, b, 1))

        # call recursive function for each recursion slice
        corr = gamma
        for t, b, repeat in reversed(brackets):
            probs = pnt[..., t:b + 1]
            if t > 0:
                probs = probs.at[..., 0].set(1)
            if repeat > 1:
                head = probs[..., 0:1]
                one = jnp.ones_like(head)
                probs = jnp.concatenate(sum(reversed([
                    [head if i == 0 else one, p]
                    for i, p in enumerate(jnp.split(probs[..., 1:], repeat, axis=-1))
                ]), start=[]), axis=-1)
            else:
                repeat = None
            corr = cls._correlation_vectorized(
                splitsbefore_or_totalsplits,
                splitsbetween_or_index1,
                splitsafter_or_index2,
                probs, corr, weights,
                debug, altinput, repeat,
            )
        return corr

    # TODO public method to compute pnt

    @staticmethod
    def _gamma(p, pnt):
        # gamma(alpha, beta, maxd) =
        #   = (gamma_0 - gamma_d maxd) (1 - alpha^s 2^(-t beta)) =
        #   = (gamma_0 - gamma_d maxd) (1 - P0^(s-t) P1^t)

        gamma_0 = 0.611 + 0.021 * jnp.exp(-1.3 * (p - 1))
        gamma_d = -0.0034 + 0.084 * jnp.exp(-2.02 * (p - 1))
        s = 2.03 - 0.69 * jnp.exp(-0.72 * (p - 1))
        t = 4.01 - 1.49 * jnp.exp(-0.77 * (p - 1))

        maxd = pnt.shape[-1] - 1
        floor = jnp.clip(gamma_0 - gamma_d * maxd, 0, 1)

        P0 = pnt[..., 0]
        P1 = jnp.minimum(P0, pnt[..., 1])
        corner = jnp.where(P0, 1 - P0 ** (s - t) * P1 ** t, 1)

        return floor * corner

    # TODO make this public?

    @staticmethod
    def _check_x(x):
        x = _array.asarray(x)
        if x.dtype.names:
            x = recfunctions.structured_to_unstructured(x)
        return x

    @staticmethod
    def _check_splits(splits, indices):
        l, s = splits
        l = jnp.asarray(l)
        assert l.ndim == 1
        if not indices:
            s = jnp.asarray(s)
            assert 1 <= s.ndim <= 2
            if s.ndim == 1:
                s = s[:, None]
            assert l.size == s.shape[1]
        with _jaxext.skipifabstract():
            assert jnp.all((0 <= l) & (l <= s.shape[0])), 'length out of bounds'
            if not indices:
                assert jnp.all(jnp.sort(s, axis=0) == s), 'unsorted splitting points'
        return l, s

    @staticmethod
    @functools.partial(jax.jit, static_argnames=('side',))
    def _searchsorted_vectorized(A, V, **kw):
        """
        A : (n, p)
        V : (..., p)
        out : (..., p)
        """
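        # Column-by-column searchsorted: elementwise, out[..., i] equals
        # jnp.searchsorted(A[:, i], V[..., i], **kw); the lax.scan below loops
        # over the p columns without unrolling them.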
        def loop(_, av):
            return _, jnp.searchsorted(*av, **kw)
        _, out = lax.scan(loop, None, (A.T, V.T))
        return out.T

    @classmethod
    @functools.partial(jax.jit, static_argnums=(0, 7))
    def _correlation_old(cls, nminus, n0, nplus, pnt, gamma, w, debug):
        """ old version, kept around for cross-checking """

        assert nminus.shape == n0.shape == nplus.shape == w.shape
        assert nminus.ndim == 1 and nminus.size >= 0
        assert pnt.ndim == 1 and pnt.size > 0
        # TODO repeat these shape checks in BART.correlation such that the
        # error messages are user-legible

        # optimization to avoid looping over ignored axes
        nminus = jnp.where(w, nminus, 0)
        n0 = jnp.where(w, n0, 0)
        nplus = jnp.where(w, nplus, 0)

        float_type = _jaxext.float_type(pnt, gamma, w)

        if nminus.size == 0:
            return jnp.array(1, float_type)

        anyn0 = jnp.any(jnp.logical_and(n0, w))

        if pnt.size == 1:
            return jnp.where(anyn0, 1 - (1 - gamma) * pnt[0], 1)

        nout = nminus + nplus
        n = nout + n0
        Wn = jnp.sum(jnp.where(n, w, 0)) # <-- @

        if pnt.size == 2 and not debug:
            Q = 1 - (1 - gamma) * pnt[1]
            sump = Q * jnp.sum(jnp.where(n, w * nout / n, 0)) # <-- @
            return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

        if pnt.size == 3 and not debug:
            Q = 1 - (1 - gamma) * pnt[2]
            s = w * nout / n
            S = jnp.sum(jnp.where(n, s, 0)) # <-- @
            t = w * n0 / n
            psin = jspecial.digamma(n.astype(float_type))
            def terms(nminus, nplus):
                nminus0 = nminus + n0
                Wnmod = Wn - jnp.where(nminus0, 0, w)
                frac = jnp.where(nminus0, w * nminus / nminus0, 0)
                terms1 = (S - s + frac) / Wnmod
                psi1nminus0 = jspecial.digamma((1 + nminus0).astype(float_type))
                terms2 = ((nplus - 1) * (S + t) - w * n0 * (psin - psi1nminus0)) / Wn
                return jnp.where(nplus, terms1 + terms2, 0)
            tplus = terms(nminus, nplus)
            tminus = terms(nplus, nminus)
            tall = jnp.where(n, w * (tplus + tminus) / n, 0)
            sump = (1 - pnt[1]) * S + pnt[1] * Q * jnp.sum(tall) # <-- @
            return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

        # TODO the pnt.size == 3 calculation is probably less accurate than
        # the recursive one, see comparison limits > 30 ULP in test_bart.py

        p = len(nminus)

        val = (0., nminus, n0, nplus)
        def loop(i, val):
            sump, nminus, n0, nplus = val

            nminusi = nminus[i]
            n0i = n0[i]
            nplusi = nplus[i]
            ni = nminusi + n0i + nplusi

            val = (0., nminus, n0, nplus, i, nminusi)
            def loop(k, val):
                sumn, nminus, n0, nplus, i, nminusi = val

                # here I use the fact that .at[].set won't set the value if the
                # index is out of bounds
                nminus = nminus.at[jnp.where(k < nminusi, i, i + p)].set(k)
                nplus = nplus.at[jnp.where(k >= nminusi, i, i + p)].set(k - nminusi)

                sumn += cls._correlation_old(nminus, n0, nplus, pnt[1:], gamma, w, debug)

                nminus = nminus.at[i].set(nminusi)
                nplus = nplus.at[i].set(nplusi)

                return sumn, nminus, n0, nplus, i, nminusi

            # if ni == 0 I skip recursion by passing 0 as iteration end
            end = jnp.where(ni, nminusi + nplusi, 0)
            start = jnp.zeros_like(end)
            sumn, nminus, n0, nplus, _, _ = lax.fori_loop(start, end, loop, val)

            sump += jnp.where(ni, w[i] * sumn / ni, 0)

            return sump, nminus, n0, nplus

        # skip summation if all(n0 == 0)
        end = jnp.where(anyn0, p, 0)
        sump, _, _, _ = lax.fori_loop(0, end, loop, val)

        return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

    @staticmethod
    def _scan_but_first(f, init, xs):
        """ lax.scan, but execute the first cycle separately. The point is that
        I use it when the first cycle works on smaller arrays due to
        broadcasting. """
        assert isinstance(xs, jnp.ndarray)
        assert len(xs) > 0
        init, out = f(init, xs[0])
        assert out is None
        if len(xs) == 1:
            return init, out
        elif len(xs) == 2:
            return f(init, xs[1])
        else:
            return lax.scan(f, init, xs[1:])

    @classmethod
    @functools.partial(jax.jit, static_argnums=(0, 7, 8))
    def _correlation(cls, n, ix, iy, pnt, gamma, w, debug, repeat):
        # this implementation is optimized assuming that the shapes are as
        # follows:
        # n     (p,)
        # ix    (n, 1, p)
        # iy    (1, n, p)
        # pnt   (d,)
        # gamma () or (n, n)
        # w     (p,)

        assert n.ndim == 1
        assert n.shape == ix.shape == iy.shape == w.shape
        assert pnt.ndim == 1 and pnt.size > 0
        assert gamma.ndim == 0
        # TODO repeat these shape checks in BART.correlation such that the
        # error messages are user-legible

        # check the strict conditions under which `repeat` is implemented
        if repeat is not None:
            assert (
                not debug
                and repeat > 0
                and pnt.size % repeat == 0
                and pnt.size // repeat <= 3
            )
        else:
            repeat = 1

        # infer float type from float arguments
        flt = _jaxext.float_type(pnt, gamma, w)

        # no covariates, always return 1
        if n.size == 0:
            return jnp.array(1, flt)

        # pre-cast all floats to the common type, to avoid unwanted float32
        # calculations in mixed float-integer operations
        pnt = pnt.astype(flt)
        gamma = gamma.astype(flt)
        w = w.astype(flt)

        # ignore zero-weight axes
        n = jnp.where(w, n, 0)
        ix = jnp.where(w, ix, 0)
        iy = jnp.where(w, iy, 0)

        # check if the points coincide
        seed = jnp.uint64(16132933535611723338)
        hx = _jaxext.fasthash64(ix, seed)
        hy = _jaxext.fasthash64(iy, seed)
        anyn0 = hx != hy
        # no hash collision check: it would have to be branchless because of
        # vmap, and the probability of a collision when building an n x n
        # matrix with n = 10000 is -expm1(10000**2 * log1p(-1/2**64)) = 5e-12.

        # base case of the recursion, no dependence on points apart from the
        # case when they are equal
        if pnt.size // repeat == 1:
            def loop(carry, pnt):
                anyn0, gamma = carry
                gamma = jnp.where(anyn0, 1 - (1 - gamma) * pnt[0], 1)
                return (anyn0, gamma), None
            (_, gamma), _ = cls._scan_but_first(loop, (anyn0, gamma), pnt.reshape(repeat, -1))
            return gamma

        # normalization for axes weights
        Wn = jnp.sum(jnp.where(n, w, 0))

        # shortcut for the last two levels of the recursion
        if pnt.size // repeat == 2 and not debug:
            n0 = jnp.abs(ix - iy)
            sum_term = jnp.where(n, w / n, 0) @ n0
            def loop(carry, pnt):
                anyn0, Wn, sum_term, gamma = carry
                Q = 1 - pnt[1] + gamma * pnt[1]
                P0 = pnt[0]
                result = 1 - P0 + Q * (P0 - P0 / Wn * sum_term)
                gamma = jnp.where(anyn0, result, 1)
                return (anyn0, Wn, sum_term, gamma), None
            (_, _, _, gamma), _ = cls._scan_but_first(loop, (anyn0, Wn, sum_term, gamma), pnt.reshape(repeat, -1))
            return gamma

        # convert to alternative format
        xlty = ix < iy
        minxy = jnp.where(xlty, ix, iy)
        maxxy = jnp.where(xlty, iy, ix)
        n0 = maxxy - minxy

        # shortcut for the last three levels of the recursion
        if pnt.size // repeat == 3 and not debug:
            nminus0 = maxxy
            nplus0 = n - minxy
            nout = n - n0

            inv_Wn = 1 / Wn
            inv_Wnmod = 1 / (Wn - jnp.where(n, w, 0))
            inv_Wnminus = jnp.where(nplus0, inv_Wn, inv_Wnmod)
            inv_Wnplus = jnp.where(nminus0, inv_Wn, inv_Wnmod)
            wn = jnp.where(n, w / n, 0)
            S = wn @ nout

            t = wn * n0
            terms1 = (S + t) * (inv_Wnminus + inv_Wnplus + inv_Wn * (nout - 2))

            terms2 = jnp.where(nplus0, w * inv_Wn * n0 / nplus0, w * inv_Wnmod)
            terms2 += jnp.where(nminus0, w * inv_Wn * n0 / nminus0, w * inv_Wnmod)

            psin = jspecial.digamma(jnp.where(n, n, 1).astype(flt))
            psiminus = jnp.where(xlty,
                jspecial.digamma((1 + iy).astype(flt)),
                jspecial.digamma((1 + ix).astype(flt)),
            )
            psiplus = jnp.where(xlty,
                jspecial.digamma((1 + n - ix).astype(flt)),
                jspecial.digamma((1 + n - iy).astype(flt)),
            )
            terms3 = w * inv_Wn * n0 * (2 * psin - psiminus - psiplus)

            terms = terms1 - terms2 - terms3
            sumi = wn @ terms

            def loop(carry, pnt):
                anyn0, inv_Wn, S, sumi, gamma = carry
                Q = 1 + pnt[2] * (gamma - 1)
                sump = S + pnt[1] * (Q * sumi - S)
                result = 1 + pnt[0] * (inv_Wn * sump - 1)
                gamma = jnp.where(anyn0, result, 1)
                return (anyn0, inv_Wn, S, sumi, gamma), None
            (_, _, _, _, gamma), _ = cls._scan_but_first(loop, (anyn0, inv_Wn, S, sumi, gamma), pnt.reshape(repeat, -1))
            return gamma

        # finish conversion to alternative format
        nminus = minxy
        nplus = n - maxxy
        p = len(nminus)
        del ix, iy, maxxy, minxy

        val = (0., nminus, n0, nplus)
        def loop(i, val):
            sump, nminus, n0, nplus = val

            nminusi = nminus[i]
            n0i = n0[i]
            nplusi = nplus[i]
            ni = nminusi + n0i + nplusi

            val = (0., nminus, n0, nplus, i, nminusi)
            def loop(k, val):
                sumn, nminus, n0, nplus, i, nminusi = val

                # here I use the fact that .at[].set won't set the value if the
                # index is out of bounds
                nminus = nminus.at[jnp.where(k < nminusi, i, i + p)].set(k)
                nplus = nplus.at[jnp.where(k >= nminusi, i, i + p)].set(k - nminusi)

                n = nminus + n0 + nplus
                ix = nminus
                iy = nminus + n0
                sumn += cls._correlation(n, ix, iy, pnt[1:], gamma, w, debug, None)

                nminus = nminus.at[i].set(nminusi)
                nplus = nplus.at[i].set(nplusi)

                return sumn, nminus, n0, nplus, i, nminusi

            # if ni == 0 I skip recursion by passing 0 as iteration end
            end = jnp.where(ni, nminusi + nplusi, 0)
            start = jnp.zeros_like(end)
            sumn, nminus, n0, nplus, _, _ = lax.fori_loop(start, end, loop, val)

            sump += jnp.where(ni, w[i] * sumn / ni, 0)

            return sump, nminus, n0, nplus

        # skip summation if all(n0 == 0)
        end = jnp.where(anyn0, p, 0)
        sump, _, _, _ = lax.fori_loop(0, end, loop, val)

        return jnp.where(anyn0, 1 - pnt[0] * (1 - sump / Wn), 1)

    @classmethod
    @functools.partial(jnp.vectorize, excluded=(0, 7, 8, 9), signature='(p),(p),(p),(d),(),(p)->()')
    def _correlation_vectorized(cls, nminus_or_n, n0_or_ix, nplus_or_iy, pnt, gamma, w, debug, altinput, repeat):
        if altinput:
            func = lambda *args: cls._correlation(*args, repeat)
        else:
            func = cls._correlation_old
        return func(nminus_or_n, n0_or_ix, nplus_or_iy, pnt, gamma, w, bool(debug))