Coverage for src/lsqfitgp/_kernels/_matern.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_kernels/_matern.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20from jax import numpy as jnp 1feabcd

21 

22from .. import _jaxext 1feabcd

23from .. import _special 1feabcd

24from .._Kernel import isotropickernel 1feabcd

25 

26def _maternp_derivable(p=None): 1feabcd

27 return p 1eabcd

28 

@isotropickernel(derivable=_maternp_derivable)
def Maternp(r2, p=None):
    """
    Matérn kernel of half-integer order.

    .. math::
        k(r) &= \\frac {2^{1-\\nu}} {\\Gamma(\\nu)} x^\\nu K_\\nu(x) = \\\\
        &= \\exp(-x) \\frac{p!}{(2p)!}
        \\sum_{i=0}^p \\frac{(p+i)!}{i!(p-i)!} (2x)^{p-i} \\\\
        \\nu &= p + 1/2,
        p \\in \\mathbb N,
        x = \\sqrt{2\\nu} r

    The process can be derived :math:`p` times.

    Reference: Rasmussen and Williams (2006, p. 85).
    """
    # Check that the order is a non-negative integer. The check is wrapped in
    # skipifabstract, presumably so it is skipped when p is a traced (abstract)
    # value -- NOTE(review): confirm against _jaxext.
    with _jaxext.skipifabstract():
        assert int(p) == p and p >= 0, p
    # (2p + 1) r^2 = 2 nu r^2 = x^2, the squared rescaled distance.
    x2 = (2 * p + 1) * r2
    # The 1e-30 offset keeps the argument strictly positive, apparently to
    # work around the derivative at x2 = 0.
    # TODO see if I can remove the 1e-30 improving kvmodx2_hi_jvp
    return _special.kvmodx2_hi(x2 + 1e-30, p)

51 

def _matern_derivable(nu=None):
    # The real-order Matérn process is ceil(nu) - 1 times derivable, clipped
    # at zero so that 0 <= nu <= 1 yields "not derivable".
    with _jaxext.skipifabstract():
        degree = jnp.ceil(nu) - 1
        return int(max(0, degree))

55 

@isotropickernel(derivable=_matern_derivable)
def Matern(r2, nu=None):
    """
    Matérn kernel of real order.

    .. math::
        k(r) = \\frac {2^{1-\\nu}} {\\Gamma(\\nu)} x^\\nu K_\\nu(x),
        \\quad \\nu \\ge 0,
        \\quad x = \\sqrt{2\\nu} r

    The process is :math:`\\lceil\\nu\\rceil-1` times derivable: for
    :math:`0 \\le \\nu \\le 1` it is not derivable, for :math:`1 < \\nu \\le 2`
    it is derivable once but has no second derivative, and so on. The highest
    derivative is (Lipschitz) continuous iff :math:`\\nu\\bmod 1 \\ge 1/2`.

    Reference: Rasmussen and Williams (2006, p. 84).
    """
    with _jaxext.skipifabstract():
        assert 0 <= nu < jnp.inf, nu
    # x^2 = 2 nu r^2. For nu = 0 the correct limit is white noise, so the
    # factor nu is replaced by 1 there instead of multiplying r2 by zero.
    scale = 2 * jnp.where(nu, nu, 1)
    return _special.kvmodx2(nu, scale * r2)

    # TODO broken for high nu. However the convergence to ExpQuad is extremely
    # slow. Tentative temporary patch:
    # - for large x, when x^v=inf, use https://dlmf.nist.gov/10.25.E3
    # - for small x, when Kv(x)=inf, return 1
    # - for very large v, use expquad even if it's not good enough

    # The GSL has log K_nu
    # https://www.gnu.org/software/gsl/doc/html/specfunc.html#irregular-modified-bessel-functions-fractional-order

87 

88# def _bessel_scale(nu): 

89# lnu = numpy.floor(nu) 

90# rnu = numpy.ceil(nu) 

91# zl, = special.jn_zeros(lnu, 1) 

92# if lnu == rnu: 

93# return zl 

94# else: 

95# zr, = special.jn_zeros(rnu, 1) 

96# return zl + (nu - lnu) * (zr - zl) / (rnu - lnu) 

97 

def _bessel_derivable(nu=0):
    # The Bessel kernel can be derived floor(nu / 2) times.
    with _jaxext.skipifabstract():
        half_order = nu // 2
        return int(half_order)

101 

102# TODO looking at the plot in the reference, it seems derivable also for nu = 0. 

103# what's up? investigate numerically by overwriting the derivability. rasmussen 

104# does not say anything about it. Problem: my custom derivatives may not work 

105# properly in this case. 

106 

def _bessel_maxdim(nu=0):
    # Maximum number of input dimensions allowed for the Bessel kernel:
    # 2 (floor(nu) + 1).
    # NOTE(review): the positive-definiteness condition nu >= (D - 2) / 2
    # gives D <= 2 (nu + 1), which would admit floor(2 nu) + 2 dimensions;
    # the value returned here is never larger, i.e. it is conservative for
    # non-integer nu -- confirm whether that is intended.
    with _jaxext.skipifabstract():
        floor_nu = jnp.floor(nu)
        return 2 * int(floor_nu + 1)

110 

@isotropickernel(derivable=_bessel_derivable, maxdim=_bessel_maxdim)
def Bessel(r2, nu=0):
    """
    Bessel kernel.

    .. math:: k(r) = \\Gamma(\\nu + 1) 2^\\nu (sr)^{-\\nu} J_{\\nu}(sr),
        \\quad s = 2 + \\nu / 2,
        \\quad \\nu \\ge 0,

    where :math:`s` is a rough estimate of the half width at half maximum of
    :math:`J_\\nu`. Valid in up to :math:`2(\\lfloor\\nu\\rfloor + 1)`
    dimensions; can be derived up to :math:`\\lfloor\\nu/2\\rfloor` times.

    Reference: Rasmussen and Williams (2006, p. 89).
    """
    with _jaxext.skipifabstract():
        assert 0 <= nu < jnp.inf, nu
    # Rescale the squared distance by s^2 so the Bessel argument is s r.
    s = 2 + nu / 2
    x2 = r2 * s ** 2
    return _special.gamma(nu + 1) * _special.jvmodx2(nu, x2)

    # nu >= (D-2)/2
    # 2 nu >= D - 2
    # 2 nu + 2 >= D
    # D <= 2 (nu + 1)