Coverage for src/lsqfitgp/_kernels/_celerite.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_kernels/_celerite.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import jax 1feabcd

21from jax import numpy as jnp 1feabcd

22 

23from .. import _jaxext 1feabcd

24from .._Kernel import stationarykernel 1feabcd

25 

def _Celerite_derivable(**kw):
    """
    Report how many times `Celerite` is derivable for the given parameters.

    ``kw`` holds the kernel hyperparameters (``gamma`` defaults to 1, ``B``
    to 0). Returns ``1`` when both are scalars and ``B == gamma``, otherwise
    ``False``.
    """
    gamma = kw.get('gamma', 1)
    B = kw.get('B', 0)
    # Non-scalar (array) parameters: do not attempt the equality test.
    if not (jnp.isscalar(gamma) and jnp.isscalar(B)):
        return False
    return 1 if B == gamma else False

33 

@stationarykernel(derivable=_Celerite_derivable, input='abs', maxdim=1)
def Celerite(delta, gamma=1, B=0):
    """
    Celerite kernel.

    .. math::
        k(\\Delta) = \\exp(-\\gamma|\\Delta|)
        \\big( \\cos(\\Delta) + B \\sin(|\\Delta|) \\big)

    This is the covariance function of an AR(2) process with complex roots.
    The parameters are required to satisfy :math:`|B| \\le \\gamma`. In the
    special case :math:`B = \\gamma` the kernel coincides with the `Harmonic`
    kernel with :math:`\\eta Q = 1/B, Q > 1`, and it is then derivable.

    Reference: Daniel Foreman-Mackey, Eric Agol, Sivaram Ambikasaran, and Ruth
    Angus: *Fast and Scalable Gaussian Process Modeling With Applications To
    Astronomical Time Series*.
    """
    # Parameter checks; skipped when the values are abstract (e.g. traced).
    with _jaxext.skipifabstract():
        assert 0 <= gamma < jnp.inf, gamma
        assert abs(B) <= gamma, (B, gamma)

    # With input='abs' in the decorator, delta is presumably already |Delta|,
    # so sin(delta) implements sin(|Delta|) from the formula above.
    envelope = jnp.exp(-gamma * delta)
    oscillation = jnp.cos(delta) + B * jnp.sin(delta)
    return envelope * oscillation

56 

@stationarykernel(derivable=1, maxdim=1)
def Harmonic(delta, Q=1):
    """
    Damped stochastically driven harmonic oscillator kernel.

    .. math::
        k(\\Delta) =
        \\exp\\left( -\\frac {|\\Delta|} {Q} \\right)
        \\begin{cases}
            \\cosh(\\eta\\Delta) + \\sinh(\\eta|\\Delta|) / (\\eta Q)
            & 0 < Q < 1 \\\\
            1 + |\\Delta| & Q = 1 \\\\
            \\cos(\\eta\\Delta) + \\sin(\\eta|\\Delta|) / (\\eta Q)
            & Q > 1,
        \\end{cases}

    where :math:`\\eta = \\sqrt{|1 - 1/Q^2|}`.

    The process solves the stochastic differential equation

    .. math:: f''(x) + 2/Q f'(x) + f(x) = w(x),

    where :math:`w` is white noise.

    The parameter :math:`Q` is the quality factor, i.e., the ratio between the energy
    stored in the oscillator and the energy lost in each cycle due to damping.
    The angular frequency is 1, i.e., the period is 2π. The process is derivable
    one time.

    In 1D, for :math:`Q = 1` (default) and ``scale=sqrt(1/3)``, it is the Matérn 3/2
    kernel.

    Reference: Daniel Foreman-Mackey, Eric Agol, Sivaram Ambikasaran, and Ruth
    Angus: *Fast and Scalable Gaussian Process Modeling With Applications To
    Astronomical Time Series*.
    """

    # TODO improve and test the numerical accuracy for derivatives near x=0
    # and Q=1. I don't know if the derivatives have problems away from Q=1.

    # TODO probably second derivatives w.r.t. Q at Q=1 are wrong.

    # TODO will fail if Q is traced.

    with _jaxext.skipifabstract():
        assert 0 < Q < jnp.inf, Q

    tau = jnp.abs(delta)

    # In every branch, eta_q stands for eta * Q = sqrt(|1 - Q^2|) and
    # scaled for tau / Q, so eta * tau == eta_q * scaled.

    if Q < 1/2:
        # Strongly overdamped. cosh/sinh are expanded into the two exponentials
        # exp(-tau/Q +- eta tau), combining the exponents so that large tau/Q
        # presumably does not overflow cosh/sinh while exp(-tau/Q) underflows.
        # The growing exponent (sqrt(1 - Q^2) - 1) tau / Q goes through
        # _sqrt1pm1 to avoid cancellation for small Q.
        eta_q = jnp.sqrt((1 - Q) * (1 + Q))
        scaled = tau / Q
        grow = jnp.exp(_sqrt1pm1(-jnp.square(Q)) * scaled)
        decay = jnp.exp(-(1 + eta_q) * scaled)
        return (grow + decay + (grow - decay) / eta_q) / 2

    if Q < 1:
        # Mildly overdamped (1/2 <= Q < 1): direct hyperbolic formula.
        eta_q = jnp.sqrt(1 - jnp.square(Q))
        scaled = tau / Q
        phase = eta_q * scaled
        return jnp.exp(-scaled) * (jnp.cosh(phase) + jnp.sinh(phase) / eta_q)

    if Q == 1:
        # Critically damped: eta == 0, handled by a dedicated expansion.
        return _harmonic(tau, Q)

    # Underdamped (Q > 1): oscillating formula.
    eta_q = jnp.sqrt(jnp.square(Q) - 1)
    scaled = tau / Q
    phase = eta_q * scaled
    return jnp.exp(-scaled) * (jnp.cos(phase) + jnp.sin(phase) / eta_q)

127 

def _sqrt1pm1(x):
    """Compute ``sqrt(1 + x) - 1`` without catastrophic cancellation near ``x = 0``."""
    # sqrt(1 + x) = exp(log1p(x) / 2); expm1 then subtracts the 1 accurately.
    half_log = jnp.log1p(x) / 2
    return jnp.expm1(half_log)

131 

@jax.custom_jvp
def _matern32(x):
    """Matérn 3/2 correlation ``(1 + x) exp(-x)`` with a hand-written derivative."""
    return (1 + x) * jnp.exp(-x)


def _matern32_jvp(tangent, ans, x):
    # Closed form d/dx [(1 + x) exp(-x)] = -x exp(-x); presumably preferred
    # over autodiff's exp(-x) - (1 + x) exp(-x), which cancels near x = 0.
    return -tangent * x * jnp.exp(-x)


_matern32.defjvps(_matern32_jvp)

137 

def _harmonic(x, Q):
    """
    First-order expansion in ``Q - 1`` of the `Harmonic` kernel.

    Returns ``_matern32(x / Q)`` plus a correction proportional to
    ``(1 - Q)``. At ``Q == 1``, where `Harmonic` calls it, the correction
    vanishes and the value is ``(1 + x) exp(-x)``, but the correction still
    contributes to the first derivative with respect to ``Q``.
    """
    base = _matern32(x / Q)
    correction = jnp.exp(-x/Q) * (1 - Q) * jnp.square(x) * (1 + x/3)
    return base + correction

140 

141# def _harmonic(x, Q): 

142# return np.exp(-x/Q) * (1 + x + (1 - Q) * x * (1 + x * (1 + x/3))) 

143 

144# @autograd.extend.primitive 

145# def _harmonic(x, Q): 

146# return (1 + x) * np.exp(-x) 

147# 

148# autograd.extend.defvjp( 

149# _harmonic, 

150# lambda ans, x, Q: lambda g: g * -np.exp(-x/Q) * x * (1 + (Q-1) * (1+x)), 

151# lambda ans, x, Q: lambda g: g * -np.exp(-x) * x ** 3 / 3 

152# ) # d/dQ: -np.exp(-x/Q) * (3/Q**2 - 1) * x**3 / (6 * Q**2) 

153# 

154# autograd.extend.defjvp( 

155# _harmonic, 

156# lambda g, ans, x, Q: (g.T * (-np.exp(-x) * x).T).T, 

157# lambda g, ans, x, Q: (g.T * (-np.exp(-x) * x ** 3 / 3).T).T 

158# )