Coverage for src/lsqfitgp/_special/_bessel.py: 100%
60 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_special/_bessel.py
2#
3# Copyright (c) 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1feabcd
22from scipy import special 1feabcd
23import jax 1feabcd
24from jax import numpy as jnp 1feabcd
25from jax.scipy import special as jspecial 1feabcd
27from .. import _jaxext 1feabcd
28from . import _gamma 1feabcd
# Scipy Bessel functions wrapped with _jaxext.makejaxufunc so they can be
# used on jax arrays and differentiated.
# NOTE(review): the positional arguments after the scipy function appear to
# be the derivatives with respect to each argument in order, with None
# marking an argument that is not differentiable, and `excluded` appears to
# list arguments that are passed through unchanged (not traced/broadcast) —
# TODO confirm against _jaxext.makejaxufunc.

# d/dx J0(x) = -J1(x)
j0 = _jaxext.makejaxufunc(special.j0, lambda x: -j1(x))
# d/dx J1(x) = (J0(x) - J2(x)) / 2
j1 = _jaxext.makejaxufunc(special.j1, lambda x: (j0(x) - jv(2, x)) / 2.0)
# J_v(z): no derivative w.r.t. the order v; d/dz is the first z-derivative.
jv = _jaxext.makejaxufunc(special.jv, None, lambda v, z: jvp(v, z, 1))
# n-th z-derivative of J_v; differentiating it again gives order n + 1.
# The integer derivative order n (argument index 2) is excluded.
jvp = _jaxext.makejaxufunc(special.jvp, None, lambda v, z, n: jvp(v, z, n + 1), None, excluded=(2,))

# Modified Bessel function of the second kind K_v(z) and its z-derivatives,
# wrapped with the same pattern as jv/jvp.
kv = _jaxext.makejaxufunc(special.kv, None, lambda v, z: kvp(v, z, 1))
kvp = _jaxext.makejaxufunc(special.kvp, None, lambda v, z, n: kvp(v, z, n + 1), None, excluded=(2,))

# Modified Bessel function of the first kind I_v(z) and its z-derivatives,
# wrapped with the same pattern as jv/jvp.
iv = _jaxext.makejaxufunc(special.iv, None, lambda v, z: ivp(v, z, 1))
ivp = _jaxext.makejaxufunc(special.ivp, None, lambda v, z, n: ivp(v, z, n + 1), None, excluded=(2,))
41# See jax #1870, #2466, #9956, #11002 and
42# https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py
43# to implement special functions in jax with numba
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
@jax.jit
def jvmodx2(nu, x2):
    """
    Normalized Bessel function of the first kind of the squared argument.

    Returns ``(sqrt(x2) / 2) ** -nu * J_nu(sqrt(x2))``. At ``x2 == 0``, where
    that expression is not directly computable, it is replaced by its limit
    ``1 / Gamma(nu + 1)``.

    `nu` is a non-differentiable argument of the custom JVP; derivatives with
    respect to `x2` come from the JVP rule registered with `defjvp`.
    """
    root = jnp.sqrt(x2)
    limit = 1 / _gamma.gamma(nu + 1)
    value = jv(nu, root) * (root / 2) ** -nu
    return jnp.where(x2 != 0, value, limit)
52# (1/x d/dx)^m (x^-v J_v(x)) = (-1)^m x^-(v+m) J_v+m(x)
53# (Abramowitz and Stegun, p. 361, 9.1.30)
54# --> 1/x d/dx (x^-v J_v(x)) = -x^-(v+1) J_v+1(x)
55# --> d/dx (x^-v J_v(x)) = -x^-v J_v+1(x)
56# --> d/ds ~J_v(s) =
57# = d/ds (√s/2)^-v J_v(√s) =
58# = 2^v d/ds √s^-v J_v(√s) =
59# = -2^v √s^-v J_v+1(√s) 1/(2√s) =
60# = -2^(v-1) √s^-(v+1) J_v+1(√s) =
61# = -1/4 (√s/2)^-(v+1) J_v+1(√s) =
62# = -1/4 ~J_v+1(s)
@jvmodx2.defjvp
def jvmodx2_jvp(nu, primals, tangents):
    """
    JVP rule for `jvmodx2` with respect to `x2`.

    Uses d/ds ~J_nu(s) = -~J_{nu+1}(s) / 4, so the derivative is again a
    `jvmodx2` evaluation with the order raised by one.
    """
    (s,) = primals
    (ds,) = tangents
    value = jvmodx2(nu, s)
    raised = jvmodx2(nu + 1, s)
    return value, -ds * raised / 4
@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 2))
@functools.partial(jax.jit, static_argnums=(2,))
def kvmodx2(nu, x2, norm_offset=0):
    """
    Normalized modified Bessel function of the second kind of the squared
    argument.

    Computes ``2 / Gamma(nu + norm_offset) * (x / 2) ** nu * K_nu(x)`` with
    ``x = sqrt(x2)``. At ``x2 == 0`` the value is replaced by its limit.

    Parameters
    ----------
    nu : scalar or array
        Order. Broadcasts elementwise against `x2`. It is a non-differentiable
        argument of the custom JVP.
    x2 : array
        Squared argument.
    norm_offset : int (static)
        Shift of the Gamma function in the normalization. The JVP rule lowers
        `nu` by one and raises `norm_offset` by one, which keeps the
        normalization factor ``1 / Gamma(nu + norm_offset)`` unchanged.

    Returns
    -------
    array
    """
    x = jnp.sqrt(x2)
    normal = 2 / _gamma.gamma(nu + norm_offset) * (x / 2) ** nu * kv(nu, x)

    # For nu > 0, (x/2)^nu K_nu(x) -> Gamma(nu) / 2 as x -> 0, so the limit is
    # Gamma(nu) / Gamma(nu + norm_offset) = 1 / (nu (nu+1) ... (nu+norm_offset-1)).
    # The product runs over a new trailing axis so that an array-valued nu
    # broadcasts elementwise instead of being combined with the arange
    # (identical to a flat product for scalar nu; empty product = 1).
    terms = jnp.asarray(nu)[..., None] + jnp.arange(norm_offset)
    atzero = 1 / jnp.prod(terms, axis=-1)
    atzero = jnp.where(nu > 0, atzero, 1) # for nu <= 0 there is no finite
                                          # limit (it would be inf), but in
                                          # practice it gets cancelled by a
                                          # stronger 0 when taking derivatives
                                          # of Matern and this is a cheap way
                                          # to avoid nans
    return jnp.where(x2, normal, atzero)
84# d/dx (x^v Kv(x)) = -x^v Kv-1(x) (Abrahamsen 1997, p. 43)
85# d/ds ~Kv(s) =
86# = d/ds (√s/2)^v Kv(√s) =
87# = 2^-v d/ds (√s)^v Kv(√s) =
88# = -2^-v (√s)^v Kv-1(√s) 1/(2√s) =
89# = -2^-(v+1) √s^(v-1) Kv-1(√s) =
90# = -1/4 (√s/2)^(v-1) Kv-1(√s) =
91# = -1/4 ~Kv-1(s)
@kvmodx2.defjvp
def kvmodx2_jvp(nu, norm_offset, primals, tangents):
    """
    JVP rule for `kvmodx2` with respect to `x2`.

    Uses d/ds ~K_nu(s) = -~K_{nu-1}(s) / 4. Lowering the order by one while
    raising `norm_offset` by one keeps the Gamma normalization factor the
    same, so the derivative is again a `kvmodx2` evaluation.
    """
    (s,) = primals
    (ds,) = tangents
    lowered = kvmodx2(nu - 1, s, norm_offset + 1)
    return kvmodx2(nu, s, norm_offset), -ds * lowered / 4
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
@functools.partial(jax.jit, static_argnums=(1,))
def kvmodx2_hi(x2, p):
    """
    Closed form of the normalized K_nu for half-integer order nu = p + 1/2.

    With ``x = sqrt(x2)`` the result is

        exp(-x) * sum_{k=0}^{p} c_k (2 x)^k,
        c_0 = 1,  c_{k+1} / c_k = (p - k) / ((2 p - k) (k + 1)),

    and the polynomial is evaluated with Horner's scheme on the coefficient
    ratios.

    Parameters
    ----------
    x2 : array
        Squared argument.
    p : int (static)
        Integer >= 0; nu = p + 1/2.
    """
    x = jnp.sqrt(x2)
    ratios = [(p - k) / ((2 * p - k) * (k + 1)) for k in range(p)]
    # Horner from the highest coefficient down: acc <- 1 + acc * r_k * 2x
    poly = functools.reduce(
        lambda acc, ratio: 1 + acc * ratio * 2 * x,
        reversed(ratios),
        1,
    )
    return jnp.exp(-x) * poly
@kvmodx2_hi.defjvp
def kvmodx2_hi_jvp(p, primals, tangents):
    """
    JVP rule for `kvmodx2_hi` with respect to `x2`.

    For p != 0 (nu = p + 1/2) the derivative is
    -kvmodx2_hi(x2, p - 1) / (4 (nu - 1)), with nu - 1 = p - 1/2.
    For p == 0 the function is exp(-sqrt(x2)), whose derivative
    -exp(-sqrt(x2)) / (2 sqrt(x2)) diverges at x2 = 0.
    """
    (s,) = primals
    (ds,) = tangents
    value = kvmodx2_hi(s, p)
    if p != 0:
        return value, -ds / (p - 1/2) * kvmodx2_hi(s, p - 1) / 4
    root = jnp.sqrt(s)
    # <--- problems! At s == 0 this divides by zero: the tangent is -inf
    # times the sign of ds, and nan when ds is also zero (0 * inf).
    return value, -ds * jnp.exp(-root) / (2 * root)