Coverage for src/lsqfitgp/_kernels/_randomwalk.py: 100%
63 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_kernels/_randomwalk.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import jax 1feabcd
21from jax import numpy as jnp 1feabcd
23from .. import _jaxext 1feabcd
24from .._Kernel import kernel, stationarykernel 1feabcd
@kernel(derivable=False, maxdim=1)
def Wiener(x, y):
    """
    Wiener kernel.

    .. math::
        k(x, y) = \\min(x, y), \\quad x, y \\ge 0

    A kernel representing a non-differentiable random walk starting at 0.

    Reference: Rasmussen and Williams (2006, p. 94).
    """
    # Domain check: only performed when the inputs are concrete values.
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    return jnp.minimum(x, y)
def _fracbrownian_derivable(H=1/2, K=1):
    """
    Derivability of `FracBrownian` for the given parameters.

    The kernel is derivable only for H = K = 1, where it reduces to x * y.

    Returns True/False when `H` and `K` are concrete values. Returns None
    when they are abstract JAX tracers, because the comparison cannot be
    evaluated at trace time.
    """
    try:
        return bool(H == 1) and bool(K == 1)
    except jax.errors.ConcretizationTypeError:
        # Under tracing `H == 1` is a tracer and cannot be turned into a
        # Python bool (TracerBoolConversionError subclasses this error).
        return None
@kernel(derivable=_fracbrownian_derivable, maxdim=1)
def FracBrownian(x, y, H=1/2, K=1):
    """
    Bifractional Brownian motion kernel.

    .. math::
        k(x, y) = \\frac 1{2^K} \\big(
            (|x|^{2H} + |y|^{2H})^K - |x-y|^{2HK}
        \\big), \\quad H, K \\in (0, 1]

    For :math:`H = 1/2` (default, together with the default :math:`K = 1`)
    it is the Wiener kernel. For :math:`H \\in (0, 1/2)` the increments are
    anticorrelated (strong oscillation), for :math:`H \\in (1/2, 1]` the
    increments are correlated (tends to keep a slope).

    Reference: Houdré and Villa (2003).
    """

    # TODO I think the correlation between successive same step increments
    # is 2^(2H-1) - 1 in (-1/2, 1). Maybe add this to the docstring.

    # Parameter range check: only performed when H and K are concrete.
    with _jaxext.skipifabstract():
        assert 0 < H <= 1, H
        assert 0 < K <= 1, K
    exponent = 2 * H
    pow_x = jnp.abs(x) ** exponent
    pow_y = jnp.abs(y) ** exponent
    pow_diff = jnp.abs(x - y) ** (exponent * K)
    return 1 / 2 ** K * ((pow_x + pow_y) ** K - pow_diff)
# jnp.minimum / jnp.maximum split the derivative evenly (1/2 to each
# argument) when x == y. The custom rules below instead send the whole
# derivative (1) to a single argument at a tie. The choice is made
# consistently: `_minimum` follows y and `_maximum` follows x.

@jax.custom_jvp
def _minimum(x, y):
    """Elementwise minimum with a tie-breaking derivative rule."""
    return jnp.minimum(x, y)

@_minimum.defjvp
def _minimum_jvp(primals, tangents):
    """JVP of `_minimum`: use the tangent of x where x < y, else that of y."""
    (x, y), (dx, dy) = primals, tangents
    x_is_smaller = x < y
    # recurse into _minimum so higher-order derivatives keep the same rule
    value = _minimum(x, y)
    return value, jnp.where(x_is_smaller, dx, dy)
@jax.custom_jvp
def _maximum(x, y):
    """Elementwise maximum with a tie-breaking derivative rule."""
    return jnp.maximum(x, y)

@_maximum.defjvp
def _maximum_jvp(primals, tangents):
    """JVP of `_maximum`: use the tangent of x where x >= y, else that of y."""
    (x, y), (dx, dy) = primals, tangents
    x_wins = x >= y  # ties go to x
    # recurse into _maximum so higher-order derivatives keep the same rule
    value = _maximum(x, y)
    return value, jnp.where(x_wins, dx, dy)
@kernel(derivable=1, maxdim=1)
def WienerIntegral(x, y):
    """
    Kernel for a process whose derivative is a Wiener process.

    .. math::
        k(x, y) = \\frac 12 a^2 \\left(b - \\frac a3 \\right),
        \\quad a = \\min(x, y), b = \\max(x, y), \\quad x, y \\ge 0
    """

    # TODO can I generate this algorithmically for arbitrary integration order?
    # If I don't find a closed formula I can use sympy. =>
    # JuliaGaussianProcesses implements it, copy their code

    # Domain check: only performed when the inputs are concrete values.
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    # min/max with custom derivatives so the kernel is differentiable once
    lo = _minimum(x, y)
    hi = _maximum(x, y)
    return 1/2 * lo ** 2 * (hi - lo / 3)
@kernel(derivable=False, maxdim=1)
def OrnsteinUhlenbeck(x, y):
    """
    Ornstein-Uhlenbeck process kernel.

    .. math::
        k(x, y) = \\exp(-|x - y|) - \\exp(-(x + y)),
        \\quad x, y \\ge 0

    It is a random walk plus a negative feedback term that keeps the
    asymptotic variance constant. It is asymptotically stationary; the name
    "Ornstein-Uhlenbeck" is often given to the stationary part only, which
    here is provided as `Expon`.
    """

    # TODO reference? look on wikipedia

    # Domain check: only performed when the inputs are concrete values.
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    stationary_part = jnp.exp(-jnp.abs(x - y))
    start_correction = jnp.exp(-(x + y))
    return stationary_part - start_correction
@kernel(derivable=False, maxdim=1)
def BrownianBridge(x, y):
    """
    Brownian bridge kernel.

    .. math::
        k(x, y) = \\min(x, y) - xy,
        \\quad x, y \\in [0, 1]

    It is a Wiener process conditioned on being zero at x = 1.
    """

    # TODO reference? look on wikipedia

    # TODO can this have a Hurst index? I think the kernel would be
    # (t^2H(1-s) + s^2H(1-t) + s(1-t)^2H + t(1-s)^2H - (t+s) - |t-s|^2H + 2ts)/2
    # but I have to check if it is correct. (In new kernel FracBrownianBridge.)

    # Domain check on [0, 1]: only performed when the inputs are concrete.
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(x <= 1)
        assert jnp.all(y >= 0)
        assert jnp.all(y <= 1)
    wiener = jnp.minimum(x, y)
    return wiener - x * y
def _stationaryfracbrownian_derivable(H=1/2):
    """
    Derivability of `StationaryFracBrownian` for the given Hurst index.

    The kernel is derivable only for H = 1, where it is constant.

    Returns True/False when `H` is a concrete value. Returns None when it is
    an abstract JAX tracer, because the comparison cannot be evaluated at
    trace time.
    """
    try:
        return bool(H == 1)
    except jax.errors.ConcretizationTypeError:
        # Under tracing `H == 1` is a tracer and cannot be turned into a
        # Python bool (TracerBoolConversionError subclasses this error).
        return None
@stationarykernel(derivable=_stationaryfracbrownian_derivable, input='signed', maxdim=1)
def StationaryFracBrownian(delta, H=1/2):
    """
    Stationary fractional Brownian motion kernel.

    .. math::
        k(\\Delta) = \\frac 12 (|\\Delta+1|^{2H} + |\\Delta-1|^{2H} - 2|\\Delta|^{2H}),
        \\quad H \\in (0, 1]

    Reference: Gneiting and Schlather (2006, p. 272).
    """

    # TODO older reference, see [29] is GS06.

    # Parameter range check: only performed when H is concrete.
    with _jaxext.skipifabstract():
        assert 0 < H <= 1, H
    exponent = 2 * H
    ahead = jnp.abs(delta + 1) ** exponent
    behind = jnp.abs(delta - 1) ** exponent
    here = jnp.abs(delta) ** exponent
    return 1/2 * (ahead + behind - 2 * here)

    # TODO is the bifractional version of this valid?