Coverage for src/lsqfitgp/_kernels/_basic.py: 100%
113 statements

# lsqfitgp/_kernels/_basic.py
#
# Copyright (c) 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import sys
import re
import collections

import numpy
import jax
from jax import numpy as jnp
from jax.scipy import special as jspecial

from .. import _special
from .. import _jaxext
from .. import _Kernel
from .._Kernel import kernel, stationarykernel, isotropickernel

@isotropickernel(derivable=True, input='raw')
def Constant(x, y):
    """
    Constant kernel.

    .. math::
        k(x, y) = 1

    This means that all points are completely correlated, thus it is
    equivalent to fitting with a horizontal line. This can also be seen by
    observing that :math:`1 = 1 \\cdot 1`.
    """
    return jnp.ones(jnp.broadcast_shapes(x.shape, y.shape))

@isotropickernel(derivable=False, input='raw')
def White(x, y):
    """
    White noise kernel.

    .. math::
        k(x, y) = \\begin{cases}
            1 & x = y \\\\
            0 & x \\neq y
        \\end{cases}
    """
    return _Kernel.prod_recurse_dtype(lambda x, y: x == y, x, y).astype(int)
    # TODO maybe StructuredArray should support equality and other operations

@isotropickernel(derivable=True)
def ExpQuad(r2):
    """
    Exponential quadratic kernel.

    .. math::
        k(r) = \\exp \\left( -\\frac 12 r^2 \\right)

    It is smooth and has a strict typical lengthscale, i.e., oscillations are
    strongly suppressed under a certain wavelength, and correlations are
    strongly suppressed over a certain distance.

    Reference: Rasmussen and Williams (2006, p. 83).
    """
    return jnp.exp(-1/2 * r2)
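
# Quick numerical illustration of the lengthscale claim in the docstring (a
# sketch, not library output: it just evaluates the formula above on a few
# squared distances):
#
#     r2 = jnp.array([0., 3., 6., 9.])
#     jnp.exp(-1/2 * r2)    # -> [1., 0.223, 0.0498, 0.0111]
#
# so the correlation drops below ~5% once the distance exceeds roughly 2.5.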

def _dot(x, y):
    return _Kernel.sum_recurse_dtype(lambda x, y: x * y, x, y)

@kernel(derivable=True)
def Linear(x, y):
    """
    Dot product kernel.

    .. math::
        k(x, y) = x \\cdot y = \\sum_i x_i y_i

    In 1D it is equivalent to fitting with a line passing through the origin.

    Reference: Rasmussen and Williams (2006, p. 89).
    """
    return _dot(x, y)

@isotropickernel(derivable=lambda gamma=1: gamma == 2)
def GammaExp(r2, gamma=1):
    """
    Gamma exponential kernel.

    .. math::
        k(r) = \\exp(-r^\\gamma), \\quad
        \\gamma \\in (0, 2]

    For :math:`\\gamma = 2` it is the squared exponential kernel, for
    :math:`\\gamma = 1` (default) it is the Matérn 1/2 kernel, and for
    :math:`\\gamma \\to 0` it tends to white noise plus a constant. The
    process is differentiable only for :math:`\\gamma = 2`; however, as
    :math:`\\gamma` approaches 2 the variance of the non-differentiable
    component goes to zero.

    Reference: Rasmussen and Williams (2006, p. 86).
    """
    with _jaxext.skipifabstract():
        assert 0 < gamma <= 2, gamma
    nondiff = jnp.exp(-(r2 ** (gamma / 2)))
    diff = jnp.exp(-r2)
    return jnp.where(gamma == 2, diff, nondiff)
    # The case where derivatives w.r.t. r2 can be computed is kept separate
    # because the second derivative at x=0 of x^p with floating point p is
    # nan.

    # TODO extend to gamma=0, the correct limit is
    # e^-1 constant + (1 - e^-1) white noise. Use lax.switch.

    # TODO derivatives w.r.t. gamma at gamma==2 are probably broken, although
    # I guess they are not needed since it's on the boundary of the domain
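
# Sketch of the limiting cases described in the docstring (plain evaluation
# of the formula at squared distance r2 = 4, i.e. r = 2; not library output):
#
#     jnp.exp(-(4. ** (1 / 2)))    # gamma = 1, Matérn 1/2:          exp(-2) ≈ 0.135
#     jnp.exp(-(4. ** (2 / 2)))    # gamma = 2, squared exponential: exp(-4) ≈ 0.018
#
# As gamma -> 0, r2 ** (gamma / 2) -> 1 for any r > 0, so off the diagonal
# k -> exp(-1) while k(0) = 1: a constant plus white noise, as noted above.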

@kernel(derivable=True)
def NNKernel(x, y, sigma0=1):
    """
    Neural network kernel.

    .. math::
        k(x, y) = \\frac 2 \\pi
        \\arcsin \\left( \\frac
        {
            2 (q + x \\cdot y)
        }{
            (1 + 2 (q + x \\cdot x))
            (1 + 2 (q + y \\cdot y))
        }
        \\right),
        \\quad q = \\texttt{sigma0}^2

    Kernel which is equivalent to a neural network with one infinite hidden
    layer with Gaussian priors on the weights and error function response. In
    other words, you can think of the process as a superposition of sigmoids
    where ``sigma0`` sets the dispersion of the centers of the sigmoids.

    Reference: Rasmussen and Williams (2006, p. 90).
    """

    # TODO the `2`s in the formula are a bit arbitrary. Remove them or give
    # motivation relative to the precise formulation of the neural network.

    with _jaxext.skipifabstract():
        assert 0 < sigma0 < jnp.inf
    q = sigma0 ** 2
    denom = (1 + 2 * (q + _dot(x, x))) * (1 + 2 * (q + _dot(y, y)))
    return 2/jnp.pi * jnp.arcsin(2 * (q + _dot(x, y)) / denom)

    # TODO this is not fully equivalent to an arbitrary transformation on the
    # augmented vector even if x and y are transformed, unless I support q
    # being a vector or an additional parameter.

    # TODO if arcsin has positive taylor coefficients, this can be obtained as
    # arcsin(1 + linear) * rescaling.
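
# Minimal usage sketch (assumed API: as elsewhere in lsqfitgp, the decorated
# kernel is instantiated with its parameters and then called on two
# broadcastable arrays):
#
#     nn = NNKernel(sigma0=1)
#     x = jnp.linspace(-5, 5, 11)
#     cov = nn(x[:, None], x[None, :])    # saturates where |x|, |y| >> sigma0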

@kernel
def Gibbs(x, y, scalefun=lambda x: 1):
    """
    Gibbs kernel.

    .. math::
        k(x, y) = \\sqrt{ \\frac {2 s(x) s(y)} {s(x)^2 + s(y)^2} }
        \\exp \\left( -\\frac {(x - y)^2} {s(x)^2 + s(y)^2} \\right),
        \\quad s = \\texttt{scalefun}.

    Kernel which in some sense is like a Gaussian kernel where the scale
    changes at every point. The scale is computed by the parameter `scalefun`,
    which must be a callable taking the x array and returning a scale for each
    point. By default ``scalefun`` returns 1, so the kernel reduces to a
    Gaussian kernel.

    Note that the usual ``scale`` parameter acts before ``scalefun``, so, for
    example, if ``scalefun(x) = x`` then ``scale`` has no effect. You should
    include all rescalings in ``scalefun`` to avoid surprises.

    Reference: Rasmussen and Williams (2006, p. 93).
    """
    sx = scalefun(x)
    sy = scalefun(y)
    with _jaxext.skipifabstract():
        assert jnp.all(sx > 0)
        assert jnp.all(sy > 0)
    denom = sx ** 2 + sy ** 2
    factor = jnp.sqrt(2 * sx * sy / denom)
    distsq = _Kernel.sum_recurse_dtype(lambda x, y: (x - y) ** 2, x, y)
    return factor * jnp.exp(-distsq / denom)
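
# Minimal usage sketch (assumed call convention as above; the linear scalefun
# is only an illustration):
#
#     gibbs = Gibbs(scalefun=lambda x: 1 + jnp.abs(x))
#     x = jnp.linspace(-3, 3, 7)
#     cov = gibbs(x[:, None], x[None, :])
#
# The correlation length grows with |x|, so points far from the origin stay
# correlated over longer distances than points near it.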

@stationarykernel(derivable=True, maxdim=1)
def Periodic(delta, outerscale=1):
    r"""
    Periodic Gaussian kernel.

    .. math::
        k(\Delta) = \exp \left(
            -2 \left(
                \frac {\sin(\Delta / 2)} {\texttt{outerscale}}
            \right)^2
        \right)

    A Gaussian kernel over a transformed periodic space. It represents a
    periodic process. The usual `scale` parameter sets the period, with the
    default ``scale=1`` giving a period of 2π, while `outerscale` sets the
    length scale of the correlations.

    Reference: Rasmussen and Williams (2006, p. 92).
    """
    with _jaxext.skipifabstract():
        assert 0 < outerscale < jnp.inf
    return jnp.exp(-2 * (jnp.sin(delta / 2) / outerscale) ** 2)
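
# Sketch of the roles of the period and of outerscale (assumed call
# convention as above; the values follow directly from the formula):
#
#     per = Periodic(outerscale=0.5)
#     per(jnp.array(0.), jnp.array(2 * jnp.pi))    # = 1, one full period apart
#     per(jnp.array(0.), jnp.array(jnp.pi))        # = exp(-8), half a period apart
#
# To get a period P instead of 2π, pass scale=P / (2 * jnp.pi).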

@kernel(derivable=False, maxdim=1)
def Categorical(x, y, cov=None):
    r"""
    Categorical kernel.

    .. math::
        k(x, y) = \texttt{cov}[x, y]

    A kernel over integers from 0 to N-1. The parameter `cov` is the
    covariance matrix of the values.
    """

    # TODO support sparse matrix for cov (replace jnp.asarray and numpy
    # check)

    assert jnp.issubdtype(x.dtype, jnp.integer)
    cov = jnp.asarray(cov)
    assert cov.ndim == 2
    assert cov.shape[0] == cov.shape[1]
    with _jaxext.skipifabstract():
        assert jnp.allclose(cov, cov.T)
    return cov[x, y]
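
# Minimal usage sketch (hypothetical 3-category covariance matrix; assumed
# call convention on integer arrays):
#
#     C = jnp.array([[1.0, 0.5, 0.0],
#                    [0.5, 1.0, 0.0],
#                    [0.0, 0.0, 2.0]])
#     cat = Categorical(cov=C)
#     cat(jnp.array([0, 1, 2]), jnp.array([1, 1, 2]))    # -> [0.5, 1.0, 2.0]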

@kernel
def Rescaling(x, y, stdfun=None):
    r"""
    Outer product kernel.

    .. math::
        k(x, y) = \texttt{stdfun}(x) \texttt{stdfun}(y)

    A totally correlated kernel with arbitrary variance. The parameter
    `stdfun` must be a function that takes ``x`` or ``y`` and computes the
    standard deviation at the point. It can yield negative values: points
    where `stdfun` has the same sign are totally correlated, points where it
    has opposite signs are totally anticorrelated. Use this kernel to modulate
    the variance of other kernels. By default `stdfun` returns a constant, so
    the kernel is equivalent to `Constant`.
    """
    if stdfun is None:
        stdfun = lambda x: jnp.ones(x.shape)
        # do not use np.ones_like because it does not recognize StructuredArray
        # do not use x.dtype because it could be structured
    return stdfun(x) * stdfun(y)
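
# Sketch of the variance-modulation use mentioned in the docstring (assumes,
# as in the rest of lsqfitgp, that kernel objects can be multiplied together;
# the Gaussian-shaped stdfun is just an example):
#
#     damped = Rescaling(stdfun=lambda x: jnp.exp(-x ** 2)) * ExpQuad()
#
# The resulting process behaves like ExpQuad near zero, with prior variance
# decaying to zero away from it.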

@stationarykernel(derivable=False, input='abs', maxdim=1)
def Expon(delta):
    """
    Exponential kernel.

    .. math::
        k(\\Delta) = \\exp(-|\\Delta|)

    In 1D it is equivalent to the Matérn 1/2 kernel; however, in higher
    dimensions it acts separately on each dimension, while the Matérn kernel
    is isotropic.

    Reference: Rasmussen and Williams (2006, p. 85).
    """
    return jnp.exp(-delta)

    # TODO rename Laplace, write it in terms of the 1-norm directly, then do
    # TruncLaplace that reaches zero over a prescribed box. (Or truncate with
    # an option truncbox=[(l0, r0), (l1, r1), ...]).

_bow_regexp = re.compile(r'\s|[!«»"“”‘’/()\'?¡¿„‚<>,;.:-–—]')

@kernel(derivable=False, maxdim=1)
@numpy.vectorize
def BagOfWords(x, y):
    """
    Bag of words kernel.

    .. math::
        k(x, y) &= \\sum_{w \\in \\text{words}} c_w(x) c_w(y), \\\\
        c_w(x) &= \\text{number of times word $w$ appears in $x$}

    The words are defined as non-empty substrings delimited by spaces or one
    of the following punctuation characters: ! « » " “ ” ‘ ’ / ( ) ' ? ¡ ¿ „ ‚
    < > , ; . : - – —.

    Reference: Rasmussen and Williams (2006, p. 100).
    """

    # TODO precompute the bags for x and y, then call a vectorized private
    # function.

    # TODO iterate on the shorter bag and use get on the other instead of
    # computing set intersection? Or: convert words to integers and then do
    # set intersection with sorted arrays?

    xbag = collections.Counter(_bow_regexp.split(x))
    ybag = collections.Counter(_bow_regexp.split(y))
    # zero out the empty strings produced by the regex split at the string
    # boundaries and between consecutive delimiters, so they don't count as
    # words
    xbag[''] = 0
    ybag[''] = 0
    common = set(xbag) & set(ybag)
    return sum(xbag[k] * ybag[k] for k in common)
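
# Worked example of the formula (just the arithmetic, not library output):
# for x = "the cat sat on the mat" and y = "the mat", the shared words are
# "the" (counts 2 and 1) and "mat" (counts 1 and 1), so
# k(x, y) = 2 * 1 + 1 * 1 = 3.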

# TODO add bag of characters and maybe other text kernels

@stationarykernel(derivable=False, input='abs', maxdim=1)
def HoleEffect(delta):
    """
    Hole effect kernel.

    .. math:: k(\\Delta) = (1 - \\Delta) \\exp(-\\Delta)

    Reference: Dietrich and Newsam (1997, p. 1096).
    """
    return (1 - delta) * jnp.exp(-delta)

def _cauchy_derivable(alpha=2, **_):
    return alpha == 2

@isotropickernel(derivable=_cauchy_derivable)
def Cauchy(r2, alpha=2, beta=2):
    r"""
    Generalized Cauchy kernel.

    .. math::
        k(r) = \left(1 + \frac{r^\alpha}{\beta} \right)^{-\beta/\alpha},
        \quad \alpha \in (0, 2], \beta > 0.

    In the geostatistics literature, the case :math:`\alpha=2` and
    :math:`\beta=2` (default) is known as the Cauchy kernel. In the machine
    learning literature, the case :math:`\alpha=2` (for any :math:`\beta`) is
    known as the rational quadratic kernel. For :math:`\beta\to\infty` it is
    equivalent to ``GammaExp(gamma=alpha, scale=alpha ** (1/alpha))``, while
    for :math:`\beta\to 0` it tends to ``Constant``. It is smooth only for
    :math:`\alpha=2`.

    References: Gneiting and Schlather (2004, p. 273), Rasmussen and Williams
    (2006, p. 86).
    """
    with _jaxext.skipifabstract():
        assert 0 < alpha <= 2, alpha
        assert 0 < beta, beta
    power = jnp.where(alpha == 2, r2, r2 ** (alpha / 2))
    # The case where derivatives w.r.t. r2 can be computed is kept separate
    # because the second derivative at x=0 of x^p with floating point p is
    # nan.
    return (1 + power / beta) ** (-beta / alpha)

    # TODO derivatives w.r.t. alpha at alpha==2 are probably broken, although
    # I guess they are not needed since it's on the boundary of the domain
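
# Numerical sketch of the beta -> infinity limit stated in the docstring
# (plain evaluation of the formula with alpha = 2 at squared distance 4):
#
#     (1 + 4. / 1e6) ** (-1e6 / 2)    # ≈ 0.1353
#     jnp.exp(-4. / 2)                # ≈ 0.1353, the Gaussian limit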

@isotropickernel(derivable=lambda alpha=1: alpha == 0, input='posabs')
def CausalExpQuad(r, alpha=1):
    r"""
    Causal exponential quadratic kernel.

    .. math::
        k(r) = \big(1 - \operatorname{erf}(\alpha r/4)\big)
        \exp\left(-\frac12 r^2 \right)

    From https://github.com/wesselb/mlkernels.
    """
    with _jaxext.skipifabstract():
        assert alpha >= 0, alpha
    return jspecial.erfc(alpha / 4 * r) * jnp.exp(-1/2 * jnp.square(r))
    # TODO taylor-expand erfc near 0 and use r2

    # TODO is the erfc part a standalone valid kernel? If so, separate it,
    # since this can be obtained as the product

@kernel(derivable=True, maxdim=1)
def Decaying(x, y, alpha=1):
    r"""
    Decaying kernel.

    .. math::
        k(x, y) =
        \frac{1}{(1 + x + y)^\alpha},
        \quad x, y, \alpha \ge 0

    Reference: Swersky, Snoek and Adams (2014).
    """
    # TODO high dimensional version of this, see mlkernels issue #3
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    return 1 / (x + y + 1) ** alpha
    # use x + y + 1 instead of 1 + x + y because the latter is less accurate
    # and not exactly symmetric for small x and y

@isotropickernel(derivable=False, input='posabs')
def Log(r):
    """
    Log kernel.

    .. math::
        k(r) = \\log(1 + r) / r

    From https://github.com/wesselb/mlkernels.
    """
    return jnp.log1p(r) / r

@kernel(derivable=True, maxdim=1)
def Taylor(x, y):
    """
    Exponential-like power series kernel.

    .. math::
        k(x, y) = \\sum_{k=0}^\\infty \\frac {x^k}{k!} \\frac {y^k}{k!}
        = I_0(2 \\sqrt{xy})

    It is equivalent to fitting with a Taylor series expansion around zero
    with independent priors on the coefficients, the coefficient of degree
    :math:`k` having mean zero and standard deviation :math:`1/k!`.
    """

    mul = x * y
    val = 2 * jnp.sqrt(jnp.abs(mul))
    return jnp.where(mul >= 0, jspecial.i0(val), _special.j0(val))

    # TODO reference? Maybe it's called bessel kernel in the literature?
    # => nope, bessel kernel is the J_v one

    # TODO what is the "natural" extension of this to multidim?

    # TODO probably the rescaled version of this (e^-x) makes more sense
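
# Standalone numerical check of the series identity in the docstring (math
# and scipy are used only for the check, they are not imports of this
# module):
#
#     import math, scipy.special
#     x, y = 0.7, 1.3
#     series = sum(x**k * y**k / math.factorial(k)**2 for k in range(20))
#     bessel = scipy.special.iv(0, 2 * math.sqrt(x * y))
#     # series and bessel agree to machine precision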