Coverage for src/lsqfitgp/_special/_bessel.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_bessel.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1feabcd

21 

22from scipy import special 1feabcd

23import jax 1feabcd

24from jax import numpy as jnp 1feabcd

25from jax.scipy import special as jspecial 1feabcd

26 

27from .. import _jaxext 1feabcd

28from . import _gamma 1feabcd

29 

# JAX-compatible wrappers around scipy's Bessel functions, built with the
# project helper _jaxext.makejaxufunc. The arguments after the scipy function
# are presumably the derivative rules, one per positional argument, with None
# meaning "no derivative with respect to this argument" -- TODO confirm
# against _jaxext.makejaxufunc. The lambdas look up their sibling wrappers by
# module-level name at call time, so referring to a name that is assigned
# further down (e.g. j1 inside j0's rule) is fine.

# Bessel functions of the first kind: J0' = -J1 and J1' = (J0 - J2) / 2.
j0 = _jaxext.makejaxufunc(special.j0, lambda x: -j1(x))
j1 = _jaxext.makejaxufunc(special.j1, lambda x: (j0(x) - jv(2, x)) / 2.0)
# The rule for z is the first derivative as computed by scipy's jvp; there is
# no rule for the order v.
jv = _jaxext.makejaxufunc(special.jv, None, lambda v, z: jvp(v, z, 1))
# Differentiating the n-th derivative with respect to z gives the (n+1)-th.
# excluded=(2,) presumably keeps the derivative order n out of the
# vectorized/traced arguments -- verify in _jaxext.
jvp = _jaxext.makejaxufunc(special.jvp, None, lambda v, z, n: jvp(v, z, n + 1), None, excluded=(2,))

# Modified Bessel function of the second kind, same scheme as jv/jvp.
kv = _jaxext.makejaxufunc(special.kv, None, lambda v, z: kvp(v, z, 1))
kvp = _jaxext.makejaxufunc(special.kvp, None, lambda v, z, n: kvp(v, z, n + 1), None, excluded=(2,))

# Modified Bessel function of the first kind, same scheme as jv/jvp.
iv = _jaxext.makejaxufunc(special.iv, None, lambda v, z: ivp(v, z, 1))
ivp = _jaxext.makejaxufunc(special.ivp, None, lambda v, z, n: ivp(v, z, n + 1), None, excluded=(2,))

40 

41# See jax #1870, #2466, #9956, #11002 and 

42# https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py 

43# to implement special functions in jax with numba 

44 

@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
@jax.jit
def jvmodx2(nu, x2):
    """Evaluate (x / 2)^-nu * J_nu(x) as a function of x2 = x^2.

    Parameters
    ----------
    nu :
        Order of the Bessel function; declared non-differentiable for the
        custom JVP.
    x2 :
        Squared argument, x = sqrt(x2).

    Returns
    -------
    The scaled Bessel function where x2 is nonzero, and
    1 / _gamma.gamma(nu + 1) where x2 is zero.
    """
    root = jnp.sqrt(x2)
    # Replacement value for x2 == 0, where the direct formula is not usable.
    at_origin = 1 / _gamma.gamma(nu + 1)
    away_from_origin = (root / 2) ** -nu * jv(nu, root)
    return jnp.where(x2, away_from_origin, at_origin)

51 

52# (1/x d/dx)^m (x^-v J_v(x)) = (-1)^m x^-(v+m) J_v+m(x) 

53# (Abramowitz and Stegun, p. 361, 9.1.30) 

54# --> 1/x d/dx (x^-v J_v(x)) = -x^-(v+1) J_v+1(x) 

55# --> d/dx (x^-v J_v(x)) = -x^-v J_v+1(x) 

56# --> d/ds ~J_v(s) = 

57# = d/ds (√s/2)^-v J_v(√s) = 

58# = 2^v d/ds √s^-v J_v(√s) = 

59# = -2^v √s^-v J_v+1(√s) 1/2√s = 

60# = -2^(v-1) √s^-(v+1) J_v+1(√s) = 

61# = -1/4 (√s/2)^-(v+1) J_v+1(√s) = 

62# = -1/4 ~J_v+1(s) 

63 

@jvmodx2.defjvp
def jvmodx2_jvp(nu, primals, tangents):
    """JVP rule for jvmodx2 with respect to x2.

    Uses d/ds jvmodx2(nu, s) = -jvmodx2(nu + 1, s) / 4, so the derivative is
    expressed through the same function at the next order.

    Parameters
    ----------
    nu :
        Non-differentiable order, passed first by jax.custom_jvp.
    primals, tangents :
        One-element tuples holding x2 and its tangent.

    Returns
    -------
    Tuple (primal output, tangent output).
    """
    (x2,) = primals
    (x2_dot,) = tangents
    value = jvmodx2(nu, x2)
    slope = -x2_dot * jvmodx2(nu + 1, x2) / 4
    return value, slope

69 

@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 2))
@functools.partial(jax.jit, static_argnums=(2,))
def kvmodx2(nu, x2, norm_offset=0):
    """Evaluate 2 / gamma(nu + norm_offset) * (x / 2)^nu * K_nu(x), x = sqrt(x2).

    Parameters
    ----------
    nu :
        Order of the modified Bessel function; non-differentiable.
    x2 :
        Squared argument.
    norm_offset :
        Static integer shift applied to the argument of the gamma function in
        the normalization. Default 0.

    Returns
    -------
    The normalized function where x2 is nonzero. Where x2 is zero, the value
    is 1 / prod(nu + k for k in range(norm_offset)) if nu > 0, otherwise 1.
    """
    root = jnp.sqrt(x2)
    prefactor = 2 / _gamma.gamma(nu + norm_offset)
    away_from_origin = prefactor * (root / 2) ** nu * kv(nu, root)

    limit = 1 / jnp.prod(nu + jnp.arange(norm_offset))
    # For nu < 0 the correct limit would be inf, but in practice it gets
    # cancelled by a stronger 0 when taking derivatives of Matern, and this
    # is a cheap way to avoid nans.
    limit = jnp.where(nu > 0, limit, 1)

    return jnp.where(x2, away_from_origin, limit)

83 

84# d/dx (x^v Kv(x)) = -x^v Kv-1(x) (Abrahamsen 1997, p. 43) 

85# d/ds ~Kv(s) = 

86# = d/ds (√s/2)^v Kv(√s) = 

87# = 2^-v d/ds (√s)^v Kv(√s) = 

88# = -2^-v (√s)^v Kv-1(√s) 1/(2√s) = 

89# = -2^-(v+1) √s^(v-1) Kv-1(√s) = 

90# = -1/4 (√s/2)^(v-1) Kv-1(√s) = 

91# = -1/4 ~Kv-1(s) 

92 

@kvmodx2.defjvp
def kvmodx2_jvp(nu, norm_offset, primals, tangents):
    """JVP rule for kvmodx2 with respect to x2.

    The derivative is -1/4 times the same function at order nu - 1. The
    normalization offset is raised by one so that the gamma factor
    gamma(nu + norm_offset) stays unchanged.

    Parameters
    ----------
    nu, norm_offset :
        Non-differentiable arguments, passed first by jax.custom_jvp.
    primals, tangents :
        One-element tuples holding x2 and its tangent.

    Returns
    -------
    Tuple (primal output, tangent output).
    """
    (x2,) = primals
    (x2_dot,) = tangents
    value = kvmodx2(nu, x2, norm_offset)
    lowered = kvmodx2(nu - 1, x2, norm_offset + 1)
    return value, -x2_dot * lowered / 4

100 

@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
@functools.partial(jax.jit, static_argnums=(1,))
def kvmodx2_hi(x2, p):
    """Evaluate exp(-x) times a degree-p polynomial in x = sqrt(x2).

    Corresponds to order nu = p + 1/2, with p an integer >= 0.

    Parameters
    ----------
    x2 :
        Squared argument.
    p :
        Static Python integer selecting the polynomial degree.

    Returns
    -------
    exp(-x) * (1 + x + ...), e.g. exp(-x) for p = 0 and exp(-x) * (1 + x)
    for p = 1.
    """
    root = jnp.sqrt(x2)
    # Horner scheme from the highest coefficient down; only the ratios of
    # consecutive coefficients c_{k+1} / c_k are needed, the constant term is 1.
    horner = 1
    for k in range(p - 1, -1, -1):
        ratio = (p - k) / ((2 * p - k) * (k + 1))
        horner = 1 + horner * ratio * 2 * root
    return jnp.exp(-root) * horner


@kvmodx2_hi.defjvp
def kvmodx2_hi_jvp(p, primals, tangents):
    """JVP rule for kvmodx2_hi with respect to x2.

    For p > 0 the derivative is expressed through the same function at
    p - 1. For p == 0 the function is exp(-sqrt(x2)), whose derivative
    -exp(-x) / (2 x) is evaluated directly.

    Parameters
    ----------
    p :
        Non-differentiable static integer, passed first by jax.custom_jvp.
    primals, tangents :
        One-element tuples holding x2 and its tangent.

    Returns
    -------
    Tuple (primal output, tangent output).
    """
    (x2,) = primals
    (x2_dot,) = tangents
    value = kvmodx2_hi(x2, p)
    if p != 0:
        lowered = kvmodx2_hi(x2, p - 1)
        slope = -x2_dot / (p - 1/2) * lowered / 4
    else:
        root = jnp.sqrt(x2)
        # <--- problems! The division by root makes this non-finite at
        # x2 == 0.
        slope = -x2_dot * jnp.exp(-root) / (2 * root)
    return value, slope