Coverage for src/lsqfitgp/_kernels/_randomwalk.py: 100%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_kernels/_randomwalk.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import jax 1feabcd

21from jax import numpy as jnp 1feabcd

22 

23from .. import _jaxext 1feabcd

24from .._Kernel import kernel, stationarykernel 1feabcd

25 

@kernel(derivable=False, maxdim=1)
def Wiener(x, y):
    """
    Wiener kernel.

    .. math::
        k(x, y) = \\min(x, y), \\quad x, y \\ge 0

    A kernel representing a non-differentiable random walk starting at 0.
    The point 0 belongs to the domain; the variance is zero there.

    Reference: Rasmussen and Williams (2006, p. 94).
    """
    # min(x, y) is a valid covariance only on the non-negative half-line
    # (for x = y < 0 it would give a negative variance), hence the checks.
    # NOTE(review): `skipifabstract` presumably disables these checks when
    # the inputs are abstract JAX tracers, where they cannot be evaluated --
    # confirm in _jaxext.
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    return jnp.minimum(x, y)

42 

def _fracbrownian_derivable(H=1/2, K=1):
    """
    Derivability flag passed to the `FracBrownian` kernel decorator.

    Truthy only when both the Hurst index H and the bifractional index K
    equal 1.
    """
    # TODO fails under tracing, return None if not concrete

    # Explicit short-circuit: when H != 1 the comparison result itself is
    # returned, otherwise the result of comparing K.
    hurst_is_one = H == 1
    if not hurst_is_one:
        return hurst_is_one
    return K == 1

46 

@kernel(derivable=_fracbrownian_derivable, maxdim=1)
def FracBrownian(x, y, H=1/2, K=1):
    """
    Bifractional Brownian motion kernel.

    .. math::
        k(x, y) = \\frac 1{2^K} \\big(
            (|x|^{2H} + |y|^{2H})^K - |x-y|^{2HK}
        \\big), \\quad H, K \\in (0, 1]

    With the defaults :math:`H = 1/2` and :math:`K = 1` it coincides with the
    Wiener kernel on non-negative inputs. For :math:`H \\in (0, 1/2)` the
    increments are anticorrelated (strong oscillation), while for
    :math:`H \\in (1/2, 1]` they are correlated (the process tends to keep a
    slope).

    Reference: Houdré and Villa (2003).
    """

    # TODO I think the correlation between successive same step increments
    # is 2^(2H-1) - 1 in (-1/2, 1). Maybe add this to the docstring.

    with _jaxext.skipifabstract():
        assert 0 < H <= 1, H
        assert 0 < K <= 1, K

    exponent = 2 * H
    norms_term = (jnp.abs(x) ** exponent + jnp.abs(y) ** exponent) ** K
    distance_term = jnp.abs(x - y) ** (exponent * K)
    normalization = 1 / 2 ** K
    return normalization * (norms_term - distance_term)

72 

# JAX's default derivative of min/max yields 1/2 for each argument when
# x == y. The wrappers below instead give the whole derivative to a single
# argument on ties, chosen consistently between min and max, so that the
# derivatives of min(x, y) + max(x, y) still equal those of x + y.

@jax.custom_jvp
def _minimum(x, y):
    """Elementwise minimum with a one-sided derivative rule at ties."""
    smaller = jnp.minimum(x, y)
    return smaller


@_minimum.defjvp
def _minimum_jvp(primals, tangents):
    """JVP rule of `_minimum`: on ties (x == y) the tangent of y is used."""
    x, y = primals
    dx, dy = tangents
    # The primal output goes through `_minimum` itself, so nested
    # differentiation also sees the custom rule.
    value = _minimum(x, y)
    slope = jnp.where(x < y, dx, dy)
    return value, slope

85 

@jax.custom_jvp
def _maximum(x, y):
    """Elementwise maximum with a one-sided derivative rule at ties."""
    larger = jnp.maximum(x, y)
    return larger


@_maximum.defjvp
def _maximum_jvp(primals, tangents):
    """
    JVP rule of `_maximum`: on ties (x == y) the tangent of x is used.

    This complements `_minimum`, whose rule uses the tangent of y on ties.
    As a result, the derivatives of min and max add up to those of x + y.
    """
    x, y = primals
    dx, dy = tangents
    value = _maximum(x, y)
    slope = jnp.where(x >= y, dx, dy)
    return value, slope

95 

@kernel(derivable=1, maxdim=1)
def WienerIntegral(x, y):
    """
    Kernel for a process whose derivative is a Wiener process.

    .. math::
        k(x, y) = \\frac 12 a^2 \\left(b - \\frac a3 \\right),
        \\quad a = \\min(x, y), b = \\max(x, y), \\quad x, y \\ge 0

    """

    # TODO can I generate this algorithmically for arbitrary integration order?
    # If I don't find a closed formula I can use sympy. =>
    # JuliaGaussianProcesses implements it, copy their code

    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)

    # The module's `_minimum`/`_maximum` wrappers are used instead of
    # jnp.minimum/jnp.maximum. They assign the derivative at x == y to a
    # single argument, consistently between the two, which matters because
    # this kernel is declared once-derivable.
    lo = _minimum(x, y)
    hi = _maximum(x, y)
    return 1/2 * lo ** 2 * (hi - lo / 3)

117 

@kernel(derivable=False, maxdim=1)
def OrnsteinUhlenbeck(x, y):
    """
    Ornstein-Uhlenbeck process kernel.

    .. math::
        k(x, y) = \\exp(-|x - y|) - \\exp(-(x + y)),
        \\quad x, y \\ge 0

    It is a random walk plus a negative feedback term that keeps the
    asymptotic variance constant. It is asymptotically stationary; often the
    name "Ornstein-Uhlenbeck" is given to the stationary part only, which here
    is provided as `Expon`.

    """

    # TODO reference? look on wikipedia

    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    stationary_part = jnp.exp(-jnp.abs(x - y))
    # Transient term: it makes the variance vanish at the starting point 0
    # and decays away from it.
    start_correction = jnp.exp(-(x + y))
    return stationary_part - start_correction

140 

@kernel(derivable=False, maxdim=1)
def BrownianBridge(x, y):
    """
    Brownian bridge kernel.

    .. math::
        k(x, y) = \\min(x, y) - xy,
        \\quad x, y \\in [0, 1]

    It is a Wiener process conditioned on being zero at x = 1, so the
    variance vanishes at both ends of the unit interval.
    """

    # TODO reference? look on wikipedia

    # TODO can this have a Hurst index? I think the kernel would be
    # (t^2H(1-s) + s^2H(1-t) + s(1-t)^2H + t(1-s)^2H - (t+s) - |t-s|^2H + 2ts)/2
    # but I have to check if it is correct. (In new kernel FracBrownianBridge.)

    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0) and jnp.all(x <= 1)
        assert jnp.all(y >= 0) and jnp.all(y <= 1)
    wiener_part = jnp.minimum(x, y)
    return wiener_part - x * y

163 

def _stationaryfracbrownian_derivable(H=1/2):
    """
    Derivability flag passed to the `StationaryFracBrownian` kernel
    decorator: truthy only when the Hurst index H equals 1.
    """
    hurst_is_one = H == 1
    return hurst_is_one

166 

@stationarykernel(derivable=_stationaryfracbrownian_derivable, input='signed', maxdim=1)
def StationaryFracBrownian(delta, H=1/2):
    """
    Stationary fractional Brownian motion kernel.

    .. math::
        k(\\Delta) = \\frac 12 (|\\Delta+1|^{2H} + |\\Delta-1|^{2H} - 2|\\Delta|^{2H}),
        \\quad H \\in (0, 1]

    This is the covariance between unit-step increments of a fractional
    Brownian motion with Hurst index H, taken at the signed lag
    :math:`\\Delta`.

    Reference: Gneiting and Schlather (2006, p. 272).
    """

    # TODO older reference, see [29] is GS06.

    with _jaxext.skipifabstract():
        assert 0 < H <= 1, H
    power = 2 * H
    forward = jnp.abs(delta + 1) ** power
    backward = jnp.abs(delta - 1) ** power
    center = jnp.abs(delta) ** power
    return 1/2 * (forward + backward - 2 * center)

    # TODO is the bifractional version of this valid?