Coverage for src/lsqfitgp/_kernels/_celerite.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_kernels/_celerite.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import jax 1feabcd
21from jax import numpy as jnp 1feabcd
23from .. import _jaxext 1feabcd
24from .._Kernel import stationarykernel 1feabcd
def _Celerite_derivable(**kw):
    """
    Derivability degree of the `Celerite` kernel for the given parameters.

    Returns 1 (derivable once) when both ``gamma`` and ``B`` are scalars and
    ``B == gamma``, otherwise ``False``. Missing parameters take the kernel
    defaults ``gamma=1``, ``B=0``.
    """
    gamma = kw.get('gamma', 1)
    B = kw.get('B', 0)
    # The slope of exp(-gamma t) (cos t + B sin t) at t = 0+ is B - gamma, so
    # the kernel is smooth at zero lag only when B == gamma.
    both_scalar = jnp.isscalar(gamma) and jnp.isscalar(B)
    return 1 if both_scalar and B == gamma else False
@stationarykernel(derivable=_Celerite_derivable, input='abs', maxdim=1)
def Celerite(delta, gamma=1, B=0):
    """
    Celerite kernel.

    .. math::
        k(\\Delta) = \\exp(-\\gamma|\\Delta|)
        \\big( \\cos(\\Delta) + B \\sin(|\\Delta|) \\big)

    This is the covariance function of an AR(2) process with complex roots. The
    parameters must satisfy the condition :math:`|B| \\le \\gamma`. For
    :math:`B = \\gamma` it is equivalent, after rescaling the input by
    :math:`\\eta`, to the `Harmonic` kernel with :math:`\\eta Q = 1/B, Q > 1`,
    and it is derivable.

    Reference: Daniel Foreman-Mackey, Eric Agol, Sivaram Ambikasaran, and Ruth
    Angus: *Fast and Scalable Gaussian Process Modeling With Applications To
    Astronomical Time Series*.
    """
    # Parameter checks; presumably skipped by skipifabstract when the values
    # are traced (TODO confirm against _jaxext).
    with _jaxext.skipifabstract():
        assert 0 <= gamma < jnp.inf, gamma
        assert abs(B) <= gamma, (B, gamma)
    # `delta` is already |Delta| because of input='abs' in the decorator.
    damping = jnp.exp(-gamma * delta)
    oscillation = jnp.cos(delta) + B * jnp.sin(delta)
    return damping * oscillation
@stationarykernel(derivable=1, maxdim=1)
def Harmonic(delta, Q=1):
    """
    Damped stochastically driven harmonic oscillator kernel.

    .. math::
        k(\\Delta) =
        \\exp\\left( -\\frac {|\\Delta|} {Q} \\right)
        \\begin{cases}
            \\cosh(\\eta\\Delta) + \\sinh(\\eta|\\Delta|) / (\\eta Q)
            & 0 < Q < 1 \\\\
            1 + |\\Delta| & Q = 1 \\\\
            \\cos(\\eta\\Delta) + \\sin(\\eta|\\Delta|) / (\\eta Q)
            & Q > 1,
        \\end{cases}

    where :math:`\\eta = \\sqrt{|1 - 1/Q^2|}`.

    The process is the solution to the stochastic differential equation

    .. math:: f''(x) + 2/Q f'(x) + f(x) = w(x),

    where :math:`w` is white noise.

    The parameter :math:`Q` is the quality factor, i.e., the ratio between the energy
    stored in the oscillator and the energy lost in each cycle due to damping.
    With this parametrization :math:`Q = 1` is the critically damped case,
    separating the overdamped (:math:`\\cosh`) and underdamped (:math:`\\cos`)
    forms. The undamped angular frequency is 1, i.e., the undamped period is
    2π. The process is derivable one time.

    In 1D, for :math:`Q = 1` (default) and ``scale=sqrt(1/3)``, it is the Matérn 3/2
    kernel.

    Reference: Daniel Foreman-Mackey, Eric Agol, Sivaram Ambikasaran, and Ruth
    Angus: *Fast and Scalable Gaussian Process Modeling With Applications To
    Astronomical Time Series*.
    """

    # TODO improve and test the numerical accuracy for derivatives near x=0
    # and Q=1. I don't know if the derivatives have problems away from Q=1.

    # TODO probably second derivatives w.r.t. Q at Q=1 are wrong.

    # TODO will fail if Q is traced: the branch is chosen with Python `if`
    # statements on Q, which need a concrete value.

    with _jaxext.skipifabstract():
        assert 0 < Q < jnp.inf, Q

    tau = jnp.abs(delta)
    tau_over_q = tau / Q

    if Q < 1/2:
        # Strongly overdamped. cosh/sinh are expanded into two exponentials,
        # each already multiplied by exp(-tau/Q), so that no factor overflows
        # even when eta * tau is large.
        eta_q = jnp.sqrt((1 - Q) * (1 + Q))
        # exponent (sqrt(1 - Q^2) - 1) tau / Q, computed without cancellation
        slow = jnp.exp(_sqrt1pm1(-jnp.square(Q)) * tau_over_q)
        fast = jnp.exp(-(1 + eta_q) * tau_over_q)
        return (slow + fast + (slow - fast) / eta_q) / 2

    if 1/2 <= Q < 1:
        # Mildly overdamped: direct cosh/sinh form.
        eta_q = jnp.sqrt(1 - jnp.square(Q))
        phase = eta_q * tau_over_q
        return jnp.exp(-tau_over_q) * (jnp.cosh(phase) + jnp.sinh(phase) / eta_q)

    if Q == 1:
        # Critically damped: dedicated helper so that derivatives are right.
        return _harmonic(tau, Q)

    # Q > 1: underdamped, oscillating form.
    eta_q = jnp.sqrt(jnp.square(Q) - 1)
    phase = eta_q * tau_over_q
    return jnp.exp(-tau_over_q) * (jnp.cos(phase) + jnp.sin(phase) / eta_q)
def _sqrt1pm1(x):
    """
    Compute ``sqrt(1 + x) - 1`` without cancellation for small ``x``.

    Uses ``sqrt(1 + x) = exp(log1p(x) / 2)``, so the subtraction of 1 is
    absorbed by ``expm1``.
    """
    half_log = 0.5 * jnp.log1p(x)
    return jnp.expm1(half_log)
@jax.custom_jvp
def _matern32(x):
    """Matérn 3/2 correlation ``(1 + x) * exp(-x)`` with a hand-written JVP."""
    return (1 + x) * jnp.exp(-x)

@_matern32.defjvp
def _matern32_jvp(primals, tangents):
    """Forward-mode rule: d/dx [(1 + x) exp(-x)] = -x exp(-x)."""
    (x,) = primals
    (x_dot,) = tangents
    return _matern32(x), x_dot * -x * jnp.exp(-x)
def _harmonic(x, Q):
    """
    Harmonic kernel expression used at Q = 1 (``x`` is the absolute lag).

    Computes ``matern32(x / Q) + exp(-x / Q) * (1 - Q) * x**2 * (1 + x / 3)``.
    At Q = 1 the second term vanishes and the value is ``(1 + x) exp(-x)``.
    The term is kept because it vanishes there but its Q-derivative does not:
    it makes the first derivative with respect to Q at Q = 1 equal
    ``-x**3 exp(-x) / 3``, matching the general-Q expressions.
    """
    decay = jnp.exp(-x / Q)
    correction = decay * (1 - Q) * jnp.square(x) * (1 + x / 3)
    return _matern32(x / Q) + correction
141# def _harmonic(x, Q):
142# return np.exp(-x/Q) * (1 + x + (1 - Q) * x * (1 + x * (1 + x/3)))
144# @autograd.extend.primitive
145# def _harmonic(x, Q):
146# return (1 + x) * np.exp(-x)
147#
148# autograd.extend.defvjp(
149# _harmonic,
150# lambda ans, x, Q: lambda g: g * -np.exp(-x/Q) * x * (1 + (Q-1) * (1+x)),
151# lambda ans, x, Q: lambda g: g * -np.exp(-x) * x ** 3 / 3
152# ) # d/dQ: -np.exp(-x/Q) * (3/Q**2 - 1) * x**3 / (6 * Q**2)
153#
154# autograd.extend.defjvp(
155# _harmonic,
156# lambda g, ans, x, Q: (g.T * (-np.exp(-x) * x).T).T,
157# lambda g, ans, x, Q: (g.T * (-np.exp(-x) * x ** 3 / 3).T).T
158# )